Commit d98f090

fix wip

1 parent 3e8f69a commit d98f090

File tree

5 files changed: +102 -36 lines changed

  flashinfer/fused_moe.py
  include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h
  include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h
  include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h
  include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h

flashinfer/fused_moe.py

Lines changed: 3 additions & 3 deletions

@@ -774,15 +774,15 @@ def cutlass_fused_moe(
 
 
 def trtllm_gen_fused_moe_sm100_module() -> JitSpec:
-    hash = "6b93c394210c89dccef13833c89797f1b8f8aefb"
-    tllm_gen_commit = "ce8ce46"
+    hash = "5e0cff4583554d182ae3fee461ff87b481ff3464"
+    tllm_gen_commit = "573cd5a"
     tllm_gen_config_hash = "2dc78d9"
     include_path = (
         f"{hash}/batched_gemm-{tllm_gen_commit}-{tllm_gen_config_hash}/include"
     )
     metainfo = get_cubin(
         f"{include_path}/flashinferMetaInfo",
-        "b24fd5e7ae6b20e903c866ecb1d4a68f238301ba9b76df6a536056f2059a0d56",
+        "a13e1ca232f60ca9eefb3298153aba03ccab6916748cf7e68b731d8dc4e9ccbc",
         ".h",
     )
     assert metainfo, "KernelMetaInfo.h not found"
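
Note: this hunk only repoints the prebuilt kernel artifacts. hash selects the artifact directory, tllm_gen_commit and tllm_gen_config_hash pin the trtllm-gen revision and config, and the 64-character hex string passed to get_cubin is presumably the SHA-256 checksum against which the downloaded flashinferMetaInfo header is verified. The new artifacts pick up the clamp-limit and TMA OOB plumbing added in the headers below.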

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h

Lines changed: 41 additions & 3 deletions

@@ -243,11 +243,48 @@ struct BatchedGemmData {
   // Shape is [B].
   float const* mPtrScaleGate{nullptr};
 
+  // The clamp limit for the accumulator before applying the activation.
+  // Shape is [B].
+  // Clamp is INF if nullptr.
+  // When the input is FP8 or NVFP4, the clamp has to be scaled by limit' = limit / dequantAb.
+  // If applied on SwiGlu, it will be:
+  //
+  //   x_glu = x_glu.clamp(min=None, max=limit)
+  //   x_linear = x_linear.clamp(min=-limit, max=limit)
+  //
+  // The given clamp limit applies to the dequantized values, so the order of operations would
+  // look something like this:
+  //
+  //   x0 = x0 * dqAb
+  //   x0 = clamp(x0, none, limit)
+  //   x0 = x0 * sigmoid(alpha * x0)
+  //   x1 = dqAb * x1
+  //   x1 = clamp(x1, -limit, limit)
+  //   out = qC * (x1 + beta) * x0
+  //
+  // Given that dqAb and qC are combined into scaleC, we can fold dqAb into the clamp
+  // limit and apply the clamping prior to dequantization:
+  //
+  //   x0 = clamp(x0, none, limit / dqAb)
+  //   x0 = x0 * dqAb
+  //   x0 = x0 * sigmoid(alpha * x0)
+  //   x1 = clamp(x1, -limit / dqAb, limit / dqAb)
+  //   scaleC = dqAb * qC
+  //   beta' = beta / dqAb
+  //   out = scaleC * (x1 + beta') * x0
+  //
+  // Note that this assumes scaleAb == scaleGate, which is true in the TRT-LLM MoE use-case.
+  //
+  float const* mPtrClampLimit{nullptr};
+
   // The alpha and beta for SwiGlu.
   // gatedActivation <- (x0 + beta) * activation(x1, alpha)
   // Shape is [B].
   // Alpha is 1.f if nullptr.
   // Beta is 0.f if nullptr.
+  // The formula:
+  //
+  //   out_glu = x_glu * torch.sigmoid(alpha * x_glu) * (x_linear + beta)
   float const* mPtrSwiGluAlpha{nullptr};
   float const* mPtrSwiGluBeta{nullptr};
 
@@ -630,9 +667,10 @@ int32_t BatchedGemmInterface::run(BatchedGemmConfig const& config, void* workspa
         batchedGemmData.mInputBuffers.mPtrSfB, batchedGemmData.mInputBuffers.mPtrPerTokenSfA,
         batchedGemmData.mInputBuffers.mPtrPerTokenSfB, batchedGemmData.mInputBuffers.mPtrBias,
         batchedGemmData.mOutputBuffers.mPtrSfC, batchedGemmData.mInputBuffers.mPtrScaleC,
-        batchedGemmData.mInputBuffers.mPtrScaleGate, batchedGemmData.mInputBuffers.mPtrSwiGluAlpha,
-        batchedGemmData.mInputBuffers.mPtrSwiGluBeta, batchedGemmData.mInputBuffers.mPtrRouteMap,
-        dPtrRowMax, dPtrRowMaxBars, batchedGemmData.mInputBuffers.mPtrNumNonExitingCtas,
+        batchedGemmData.mInputBuffers.mPtrScaleGate, batchedGemmData.mInputBuffers.mPtrClampLimit,
+        batchedGemmData.mInputBuffers.mPtrSwiGluAlpha, batchedGemmData.mInputBuffers.mPtrSwiGluBeta,
+        batchedGemmData.mInputBuffers.mPtrRouteMap, dPtrRowMax, dPtrRowMaxBars,
+        batchedGemmData.mInputBuffers.mPtrNumNonExitingCtas,
         batchedGemmData.mInputBuffers.mPtrTotalNumPaddedTokens,
         batchedGemmData.mInputBuffers.mPtrCtaIdxXyToBatchIdx,
         batchedGemmData.mInputBuffers.mPtrCtaIdxXyToMnLimit, maxNumCtasInBatchDim);
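
The clamp-folding argument in the comment above is easy to check numerically. Below is a minimal PyTorch sketch of the two orderings; the values of dqAb, qC, alpha, beta, and limit are made-up examples, and the kernel of course operates on quantized tensors with all of these steps fused:

import torch

torch.manual_seed(0)
x0 = 10.0 * torch.randn(8)  # gate branch accumulator, pre-dequantization
x1 = 10.0 * torch.randn(8)  # linear branch accumulator, pre-dequantization
dqAb, qC, alpha, beta, limit = 0.5, 2.0, 1.702, 1.0, 7.0

# Reference ordering: dequantize first, then clamp with the given limit.
x_glu = (x0 * dqAb).clamp(max=limit)
x_lin = (x1 * dqAb).clamp(min=-limit, max=limit)
ref = qC * (x_lin + beta) * (x_glu * torch.sigmoid(alpha * x_glu))

# Folded ordering: clamp with limit / dqAb before dequantization, then fold
# dqAb into scaleC = dqAb * qC and beta' = beta / dqAb.
y0 = x0.clamp(max=limit / dqAb) * dqAb
y1 = x1.clamp(min=-limit / dqAb, max=limit / dqAb)
out = (dqAb * qC) * (y1 + beta / dqAb) * (y0 * torch.sigmoid(alpha * y0))

assert torch.allclose(ref, out)  # equal up to float rounding

The two orderings agree because dqAb > 0, so scaling commutes with clamping: clamp(x * dqAb, limit) == clamp(x, limit / dqAb) * dqAb.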

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h

Lines changed: 39 additions & 25 deletions

@@ -90,10 +90,10 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions {
       bool usePerTokenSfB, bool useTmaStore, bool useTwoTmaLoadWarps, bool useTwoMmaWarps,
       tg::SfLayout sfLayoutA, tg::SfLayout sfLayoutB, tg::SfLayout sfLayoutC,
       int32_t sfReshapeFactor, gemm::TileScheduler tileScheduler, gemmGatedAct::ActType actType,
-      std::vector<int> batchedM, std::vector<int> batchedN, BatchMode batchMode, int numBatches,
-      bool isStaticBatch, int numTokens, RouteImpl routeImpl, bool gridWaitForPrimaryRouting,
-      bool fusedAct, int numRegsPerThreadNonEpilogueWarp, int numRegsPerThreadEpilogueWarp,
-      int numRegsCastAWarps)
+      bool clampBeforeAct, std::vector<int> batchedM, std::vector<int> batchedN,
+      BatchMode batchMode, int numBatches, bool isStaticBatch, int numTokens, RouteImpl routeImpl,
+      bool gridWaitForPrimaryRouting, bool fusedAct, int numRegsPerThreadNonEpilogueWarp,
+      int numRegsPerThreadEpilogueWarp, int numRegsCastAWarps, bool useTmaOobOpt)
       : gemmGatedAct::GemmGatedActOptions(
             gemm::GemmOptions(
                 allReduceAlgo, biasType, blockK, clusterDimX, clusterDimY, clusterDimZ, dtypeAcc,

@@ -109,48 +109,49 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions {
                 useCustomMmaSchedule, useHoistTryWaitForCustomMmaSchedule, useDeepSeekFp8,
                 usePerTokenSfA, usePerTokenSfB, useTmaStore, useTwoTmaLoadWarps, useTwoMmaWarps,
                 sfLayoutA, sfLayoutB, sfLayoutC, sfReshapeFactor, tileScheduler),
-            actType),
+            actType, clampBeforeAct),
         mBatchedM(batchedM),
         mBatchedN(batchedN),
         mBatchMode(BatchMode(batchMode)),
-        mNumBatches(numBatches),
-        mIsStaticBatch(isStaticBatch),
-        mNumTokens(numTokens),
-        mRouteImpl(routeImpl),
-        mGridWaitForPrimaryRouting(gridWaitForPrimaryRouting),
         mFusedAct(fusedAct),
+        mGridWaitForPrimaryRouting(gridWaitForPrimaryRouting),
+        mIsStaticBatch(isStaticBatch),
+        mNumBatches(numBatches),
         mNumRegsPerThreadNonEpilogueWarp(numRegsPerThreadNonEpilogueWarp),
         mNumRegsPerThreadEpilogueWarp(numRegsPerThreadEpilogueWarp),
-        mNumRegsCastAWarps(numRegsCastAWarps) {}
+        mNumRegsCastAWarps(numRegsCastAWarps),
+        mNumTokens(numTokens),
+        mRouteImpl(routeImpl),
+        mUseTmaOobOpt(useTmaOobOpt) {}
 
   // Batched M-dimensions of GEMM.
   std::vector<int> mBatchedM;
   // Batched N-dimensions of GEMM.
   std::vector<int> mBatchedN;
   // Whether batching M or N.
   BatchMode mBatchMode{BatchMode::BatchM};
-  // Number of Gemm batches.
-  int mNumBatches;
-
-  // Whether the batch size is static (i.e. known at kernel launch time).
-  bool mIsStaticBatch{true};
-  // Total number of tokens.
-  int mNumTokens{32};
-  // Whether load the input tokens and do routing.
-  RouteImpl mRouteImpl{RouteImpl::NoRoute};
+  // Whether to perform a fused gated activation.
+  bool mFusedAct{false};
   // Whether the loads that load from ptrRouteMap, ptrTotalNumPaddedTokens,
   // ptrCtaIdxXyToBatchIdx, etc. should wait on a grid dependency.
   bool mGridWaitForPrimaryRouting{true};
-
-  // Whether to perform a fused gated activation.
-  bool mFusedAct{false};
-
+  // Whether the batch size is static (i.e. known at kernel launch time).
+  bool mIsStaticBatch{true};
+  // Number of GEMM batches.
+  int mNumBatches;
   // Number of registers per thread for non-epilogue warps.
   int mNumRegsPerThreadNonEpilogueWarp{0};
   // Number of registers per thread for epilogue warps.
   int mNumRegsPerThreadEpilogueWarp{0};
   // Number of registers for the cast-A warps.
   int mNumRegsCastAWarps{0};
+  // Total number of tokens.
+  int mNumTokens{32};
+  // Whether to load the input tokens and do routing.
+  RouteImpl mRouteImpl{RouteImpl::NoRoute};
+  // Whether to use the TMA out-of-bounds optimization to reduce wasted traffic. See details in
+  // BatchedGemm/KernelParamsDecl.h.
+  bool mUseTmaOobOpt{false};
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////

@@ -159,6 +160,16 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions {
 bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, bool isBlackwell,
                                       bool updateOptions = true) {
   bool isValid = true;
+  if (options.mUseTmaOobOpt && !options.mUseTwoTmaLoadWarps) {
+    if (updateOptions) {
+      // Any routing (mRouteAct != NoRoute) requires mUseTwoTmaLoadWarps == true, so a
+      // single TMA load warp is not the target use case for the OOB optimization.
+      options.mUseTmaOobOpt = false;
+    } else {
+      TLLM_CHECK_ERROR(false, "TMA OOB optimization requires two TMA load warps.");
+      return false;
+    }
+  }
   if (options.mFusedAct) {
     // Ensure that we check the fused options as well.
     isValid = gemmGatedAct::checkAndUpdateGemmGatedActOptions(options, isBlackwell, updateOptions);

@@ -302,6 +313,8 @@ struct BatchedGemmConfig {
   // defined. In this case, the cubins will be loaded from the provided data and function name.
   // Otherwise, the kernel will be loaded from the CudaRunner.
 #ifdef TLLM_GEN_EXPORT_INTERFACE
+  uint8_t const* mData{nullptr};
+  uint32_t const mSize{0};
   uint32_t const mSharedMemSize{0};
   char const* mFunctionName{nullptr};
   uint32_t const mNumThreadsPerCTA{0};

@@ -334,7 +347,8 @@ inline std::string dumpOptions(BatchedGemmOptions const& options) {
      << std::endl;
   ss << "mNumRegsPerThreadEpilogueWarp=" << options.mNumRegsPerThreadEpilogueWarp << ","
      << std::endl;
-  ss << "mNumRegsCastAWarps=" << options.mNumRegsCastAWarps << std::endl;
+  ss << "mNumRegsCastAWarps=" << options.mNumRegsCastAWarps << "," << std::endl;
+  ss << "mUseTmaOobOpt=" << options.mUseTmaOobOpt << std::endl;
   return ss.str();
 }
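
With this change and the matching one in GemmGatedActOptions.h below, each dumped field line ends in a comma except the last, so the tail of the dumped string now looks roughly like this (illustrative values, not real output):

mNumRegsPerThreadEpilogueWarp=0,
mNumRegsCastAWarps=0,
mUseTmaOobOpt=0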

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h

Lines changed: 6 additions & 3 deletions

@@ -91,11 +91,13 @@ inline std::string getActTypeName(ActType type) {
 
 struct GemmGatedActOptions : public gemm::GemmOptions {
   GemmGatedActOptions() = default;
-  GemmGatedActOptions(gemm::GemmOptions options, ActType actType)
-      : gemm::GemmOptions(options), mActType(actType) {}
+  GemmGatedActOptions(gemm::GemmOptions options, ActType actType, bool clampBeforeAct)
+      : gemm::GemmOptions(options), mActType(actType), mClampBeforeAct(clampBeforeAct) {}
 
   // Type of the gated activation.
   ActType mActType{ActType::SwiGlu};
+  // Clamp the dequantized values to the range [-limit, limit].
+  bool mClampBeforeAct{false};
 };

@@ -156,8 +158,9 @@ inline bool checkAndUpdateGemmGatedActOptions(gemmGatedAct::GemmGatedActOptions&
 inline std::string dumpOptions(GemmGatedActOptions const& options) {
   std::stringstream ss;
   ss << gemm::dumpOptions(options) << ", ";
-  ss << "mActType=" << "gemmGatedAct::ActType(" << static_cast<int32_t>(options.mActType) << ")"
+  ss << "mActType=" << "gemmGatedAct::ActType(" << static_cast<int32_t>(options.mActType) << "),"
      << std::endl;
+  ss << "mClampBeforeAct=" << options.mClampBeforeAct << std::endl;
   return ss.str();
 }

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h

Lines changed: 13 additions & 2 deletions

@@ -212,6 +212,15 @@ struct KernelParams {
   // Shape is [B]. One scaling factor per tensor in batch.
   float const* ptrScaleGate{nullptr};
 
+  // The clamp limit before the activation.
+  // Shape is [1].
+  // Clamp is INF if nullptr.
+  // If applied on SwiGlu, it will be:
+  //
+  //   x_glu = x_glu.clamp(min=None, max=limit)
+  //   x_linear = x_linear.clamp(min=-limit, max=limit)
+  float const* ptrClampLimit{nullptr};
+
   // The alpha and beta for SwiGlu.
   // Shape is [B]. One alpha and one beta per tensor in batch.
   // Alpha is 1.f if nullptr.

@@ -695,8 +704,8 @@
       GemmOptions_ const& options, bool const batchM, void const* ptrA, void const* ptrB,
       void* ptrC, void const* dSfA, void const* dSfB, void const* ptrPerTokenSfA,
       void const* ptrPerTokenSfB, void const* ptrBias, void* dSfC, float const* ptrScaleC,
-      float const* ptrScaleGate, float const* ptrSwiGluAlpha, float const* ptrSwiGluBeta,
-      int32_t const* routeMap, float* rowMax, uint32_t* rowMaxBars,
+      float const* ptrScaleGate, float const* ptrClampLimit, float const* ptrSwiGluAlpha,
+      float const* ptrSwiGluBeta, int32_t const* routeMap, float* rowMax, uint32_t* rowMaxBars,
       int32_t const* ptrNumNonExitingCtas = nullptr,
       int32_t const* ptrTotalNumPaddedTokens = nullptr,
       int32_t const* ptrCtaIdxXyToBatchIdx = nullptr, int32_t const* ptrCtaIdxXyToMnLimit = nullptr,

@@ -713,6 +722,8 @@
     params.ptrScaleC = ptrScaleC;
     params.ptrScaleGate = ptrScaleGate;
 
+    params.ptrClampLimit = ptrClampLimit;
+
     params.ptrSwiGluAlpha = ptrSwiGluAlpha;
     params.ptrSwiGluBeta = ptrSwiGluBeta;
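
Putting the documented nullptr defaults together (clamp limit of INF, alpha of 1.f, beta of 0.f), the activation these parameters feed corresponds to roughly the following reference; a sketch of the math under those assumptions, not the kernel code:

import torch

def swiglu_clamp_ref(x_glu: torch.Tensor, x_linear: torch.Tensor,
                     alpha: float = 1.0, beta: float = 0.0,
                     limit: float | None = None) -> torch.Tensor:
    # limit=None mirrors ptrClampLimit == nullptr (clamp is INF); the alpha and
    # beta defaults mirror ptrSwiGluAlpha/ptrSwiGluBeta == nullptr.
    if limit is not None:
        x_glu = x_glu.clamp(max=limit)
        x_linear = x_linear.clamp(min=-limit, max=limit)
    return x_glu * torch.sigmoid(alpha * x_glu) * (x_linear + beta)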
