Skip to content

Commit 2a2fa30

Browse files
committed
Add trtllmgen FP4 MOE throughput kernel
Signed-off-by: jiahanc <[email protected]>
update some work
Signed-off-by: jiahanc <[email protected]>
update some work
Signed-off-by: jiahanc <[email protected]>
update some more files
Signed-off-by: jiahanc <[email protected]>
1 parent a4ddf26 commit 2a2fa30

File tree

14 files changed

+342
-158
lines changed

14 files changed

+342
-158
lines changed

csrc/trtllm_batched_gemm_runner.cu

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,17 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner(
100100
options.mTransposeMmaOutput == mOptions.transposeMmaOutput &&
101101
(!doesRouteImplUseNoRoute(options.mRouteImpl)) == mOptions.routeAct &&
102102
options.mFusedAct == mOptions.fusedAct && options.mIsStaticBatch == mOptions.staticBatch &&
103-
tileSize == mOptions.tileSize &&
104-
options.mUseShuffledMatrixA == mOptions.useShuffledMatrixA &&
105-
options.mLayoutA == mOptions.weightLayout) {
103+
tileSize == mOptions.tileSize) {
104+
auto sm = configs[i].mSm;
105+
if (sm != SmVersion::Sm100f) {
106+
int smVersion = tensorrt_llm::common::getSMVersion();
107+
if (smVersion == 100 && sm != SmVersion::Sm100a) {
108+
continue;
109+
} else if (smVersion == 103 && sm != SmVersion::Sm103a) {
110+
continue;
111+
}
112+
}
113+
106114
if (options.mFusedAct) {
107115
if (options.mActType != static_cast<batchedGemm::gemmGatedAct::ActType>(mOptions.actType)) {
108116
continue;
@@ -161,6 +169,7 @@ void TrtllmGenBatchedGemmRunner::run(
161169
auto const configs = bmm.getBatchedGemmConfigs();
162170

163171
auto const& config = configs[configIndex];
172+
std::cout << "config function name: " << config.mFunctionName << std::endl;
164173

165174
FLASHINFER_CHECK(numBatches > 0, "Batched GEMM requires numBatches > 0");
166175
if (!mOptions.staticBatch) {
@@ -367,6 +376,7 @@ std::vector<int64_t> TrtllmGenBatchedGemmRunner::getValidConfigIndices(
367376

368377
return false;
369378
};
379+
370380
// Sort configs by options.
371381
std::vector<int64_t> sortedIndices = mPassingConfigIndices;
372382
std::sort(sortedIndices.begin(), sortedIndices.end(), cmpFunc);
@@ -381,6 +391,13 @@ std::vector<int64_t> TrtllmGenBatchedGemmRunner::getValidConfigIndices(
381391
auto const& config = configs[configIndex];
382392
auto isValidConfig = bmm.isValidConfig(config, gemmData);
383393
if (isValidConfig) {
394+
// if (static_cast<int32_t>(config.mOptions.mLayoutA) == 0 ){
395+
// std::cout << "config.mLayoutA: " << static_cast<int32_t>(config.mOptions.mLayoutA) <<
396+
// std::endl; std::cout << "config.mLayoutB: " <<
397+
// static_cast<int32_t>(config.mOptions.mLayoutB) << std::endl; std::cout <<
398+
// "config.mFunctionName: " << config.mFunctionName << std::endl;
399+
// validConfigIndices.push_back(configIndex);
400+
// }
384401
validConfigIndices.push_back(configIndex);
385402
}
386403
}

flashinfer/artifacts.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def get_available_cubin_files(
7676
class ArtifactPath:
7777
TRTLLM_GEN_FMHA: str = "7206d64e67f4c8949286246d6e2e07706af5d223/fmha/trtllm-gen"
7878
TRTLLM_GEN_BMM: str = (
79-
"e6f22dcc3fdeb29ff87af2f4a2cb3d30b8d273e0/batched_gemm-45beda1-ee6a802"
79+
"696906bd3985f84662799054f377b4b47a1907d3/batched_gemm-074aec4-3df1e6c"
8080
)
8181
TRTLLM_GEN_GEMM: str = (
8282
"037e528e719ec3456a7d7d654f26b805e44c63b1/gemm-8704aa4-f91dc9e"
@@ -90,9 +90,7 @@ class MetaInfoHash:
9090
TRTLLM_GEN_FMHA: str = (
9191
"2f605255e71d673768f5bece66dde9e2e9f4c873347bfe8fefcffbf86a3c847d"
9292
)
93-
TRTLLM_GEN_BMM: str = (
94-
"c98b4ce69a39fd41556d67033c30ea814ef76b0a2fe16e798e55baf0104acc34"
95-
)
93+
TRTLLM_GEN_BMM: str = "696906bd3985f84662799054f377b4b47a1907d3"
9694
DEEPGEMM: str = "b4374f857c3066089c4ec6b5e79e785559fa2c05ce2623710b0b04bf86414a48"
9795
TRTLLM_GEN_GEMM: str = (
9896
"0345358c916d990709f9670e113e93f35c76aa22715e2d5128ec2ca8740be5ba"

flashinfer/fused_moe/core.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -894,7 +894,9 @@ def __init__(
894894
self.gated_act_type = gated_act_type
895895
self.tile_tokens_dim = tile_tokens_dim
896896

897-
def get_tile_tokens_dim(self, num_tokens: int, top_k: int):
897+
def get_tile_tokens_dim(
898+
self, num_tokens: int, top_k: int, max_tile_tokens_dim: int = 128
899+
):
898900
# Factor to account for the imbalance of the experts.
899901
# factor equals to the
900902
# max_real_num_tokens_per_expert / perfect_num_tokens_per_expert
@@ -910,10 +912,10 @@ def get_tile_tokens_dim(self, num_tokens: int, top_k: int):
910912
num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
911913
# And pad the number to the next power of 2.
912914
tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
913-
# Cap to 8-64 tokens per CTA tile
914-
# as it's the range supported by the kernel.
915-
tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
916-
915+
if num_tokens_per_expert > 128 and num_tokens_per_expert < 256:
916+
tile_tokens_dim = 192
917+
# Cap to 8-max_tile_tokens_dim tokens per CTA tile as it's the range supported by the kernel.
918+
tile_tokens_dim = min(max(tile_tokens_dim, 8), max_tile_tokens_dim)
917919
return tile_tokens_dim
918920

919921
def get_valid_tactics(
@@ -931,7 +933,7 @@ def get_valid_tactics(
931933
) = inputs
932934
num_tokens = routing_logits.shape[0]
933935
tile_tokens_dim = (
934-
self.get_tile_tokens_dim(num_tokens, self.top_k)
936+
self.get_tile_tokens_dim(num_tokens, self.top_k, 128)
935937
if self.tile_tokens_dim is None
936938
else self.tile_tokens_dim
937939
)
@@ -975,7 +977,7 @@ def forward(
975977
) = inputs
976978
num_tokens = routing_logits.shape[0]
977979
tile_tokens_dim = (
978-
self.get_tile_tokens_dim(num_tokens, self.top_k)
980+
self.get_tile_tokens_dim(num_tokens, self.top_k, 128)
979981
if self.tile_tokens_dim is None
980982
else self.tile_tokens_dim
981983
)

flashinfer/utils.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,14 +113,18 @@ def next_positive_power_of_2(x: int) -> int:
113113
return n + 1
114114

115115

116-
def calculate_tile_tokens_dim(num_tokens: int, num_experts: int, top_k: int) -> int:
116+
def calculate_tile_tokens_dim(
117+
num_tokens: int, num_experts: int, top_k: int, max_tile_tokens_dim: int = 128
118+
) -> int:
117119
# Guess tokens per expert assuming perfect expert distribution first.
118120
num_tokens_per_expert = num_tokens * top_k // num_experts
119121

120122
# And pad the number to the next power of 2.
121123
tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
122-
# Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
123-
tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
124+
if num_tokens_per_expert > 128 and num_tokens_per_expert < 256:
125+
tile_tokens_dim = 192
126+
# Cap to 8-max_tile_tokens_dim tokens per CTA tile as it's the range supported by the kernel.
127+
tile_tokens_dim = min(max(tile_tokens_dim, 8), max_tile_tokens_dim)
124128

125129
return tile_tokens_dim
126130

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -506,8 +506,19 @@ class BatchedGemmInterface {
506506
throw std::invalid_argument("Invalid combination of options");
507507
}
508508

509-
int32_t const numCtasTile =
509+
if (batchM) {
510+
numCtasBatch = gemm::divUpMul(numCtasBatch, options.mClusterDimX);
511+
} else {
512+
numCtasBatch = gemm::divUpMul(numCtasBatch, options.mClusterDimY);
513+
}
514+
515+
int32_t numCtasTile =
510516
batchM ? gemm::divUp(options.mN, options.mTileN) : gemm::divUp(options.mM, options.mTileM);
517+
if (batchM) {
518+
numCtasTile = gemm::divUpMul(numCtasTile, options.mClusterDimY);
519+
} else {
520+
numCtasTile = gemm::divUpMul(numCtasTile, options.mClusterDimX);
521+
}
511522
int32_t const numCtasInner = options.mNumSlicesForSplitK;
512523
return std::make_tuple(numCtasBatch, numCtasTile, numCtasInner);
513524
}

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h

Lines changed: 63 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -76,39 +76,43 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions {
7676
// FIXME We create explicit constructor with all options to WAR stubgen issue in TRT-LLM.
7777
BatchedGemmOptions(
7878
gemm::AllReduceAlgo allReduceAlgo, gemm::BiasType biasType, int blockK, int clusterDimX,
79-
int clusterDimY, int clusterDimZ, tg::Dtype dtypeAcc, tg::Dtype dtypeA, tg::Dtype dtypeB,
80-
tg::Dtype dtypeC, tg::Dtype dtypeMmaA, tg::Dtype dtypeMmaB, bool enablesEarlyExit,
81-
bool enablesDelayedEarlyExit, bool enablesGlobalPtxKnobs, int epilogueLdtmDps,
82-
int epilogueLdtmBits, int epilogueTileM, int epilogueTileN, bool gridTriggerSecondaryA,
83-
bool gridTriggerSecondaryB, bool gridWaitForPrimaryEarlyExit, bool gridWaitForPrimaryA,
84-
bool gridWaitForPrimaryB, bool hoistLoadTaskInit, bool hoistMmaTaskTryWaits, int k,
85-
gemm::KernelTraits kernelTraits, gemm::MatrixLayout layoutA, gemm::MatrixLayout layoutB,
86-
int m, int mmaK, tg::MmaKind mmaKind, int mmaM, int mmaN, bool mockAllReduce, int n,
87-
int numSlicesForSplitK, int numSlicesForSliceK, int numStages, int numStagesMma,
88-
int numStagesMmaWithinWorkTile, int numStagesMmaAcrossWorkTile, int numStagesWorkId,
89-
bool outputDebugTensors, bool patchF2fp, std::optional<int32_t> sfBlockSizeA,
90-
tg::SfLayout sfLayoutA, tg::SfLayout sfLayoutB, tg::SfLayout sfLayoutC,
91-
int32_t sfReshapeFactor, bool sliceK, gemm::SplitK splitK, int tileK, int tileM, int tileN,
92-
gemm::TileScheduler tileScheduler, bool transposeMmaOutput, bool useCustomMmaSchedule,
93-
bool useDeepSeekFp8, bool useHoistTryWaitForCustomMmaSchedule, bool usePerTokenSfA,
94-
bool usePerTokenSfB, bool useShuffledMatrixA, bool useTmaStore, bool useTwoTmaLoadWarps,
95-
bool useTwoMmaWarps, bool useUnrollLoop2xForMma, int worldSize, gemmGatedAct::ActType actType,
96-
bool clampBeforeAct, std::vector<int> batchedM, std::vector<int> batchedN,
97-
BatchMode batchMode, int numBatches, bool isStaticBatch, int numTokens, RouteImpl routeImpl,
98-
bool gridWaitForPrimaryRouting, bool fusedAct, int numRegsPerThreadNonEpilogueWarp,
99-
int numRegsPerThreadEpilogueWarp, int numRegsCastAWarps, bool useTmaOobOpt)
79+
int clusterDimY, int clusterDimZ, gemm::CtaSwizzleType ctaSwizzleType, tg::Dtype dtypeAcc,
80+
tg::Dtype dtypeA, tg::Dtype dtypeB, tg::Dtype dtypeC, tg::Dtype dtypeMmaA,
81+
tg::Dtype dtypeMmaB, bool enablesEarlyExit, bool enablesDelayedEarlyExit,
82+
bool enablesGlobalPtxKnobs, int epilogueLdtmDps, int epilogueLdtmBits, int epilogueTileM,
83+
int epilogueTileN, bool gridTriggerSecondaryA, bool gridTriggerSecondaryB,
84+
bool gridWaitForPrimaryEarlyExit, bool gridWaitForPrimaryA, bool gridWaitForPrimaryB,
85+
bool hoistLoadTaskInit, bool hoistMmaTaskTryWaits, int k, gemm::KernelTraits kernelTraits,
86+
gemm::MatrixLayout layoutA, gemm::MatrixLayout layoutB, int m, int mmaK, tg::MmaKind mmaKind,
87+
int mmaM, int mmaN, bool mockAllReduce, int n, int numRegsCastAWarps,
88+
int numRegsCopySfLdsSttm, int numRegsPerThreadEpilogueWarp,
89+
int numRegsPerThreadNonEpilogueWarp, int numSlicesForSplitK, int numSlicesForSliceK,
90+
int numStages, int numStagesMma, int numStagesMmaWithinWorkTile,
91+
int numStagesMmaAcrossWorkTile, int numStagesWorkId, bool outputDebugTensors, bool patchF2fp,
92+
std::optional<int32_t> sfBlockSizeA, tg::SfLayout sfLayoutA, tg::SfLayout sfLayoutB,
93+
tg::SfLayout sfLayoutC, int32_t sfReshapeFactor, bool sliceK, gemm::SplitK splitK, int tileK,
94+
int tileM, int tileN, gemm::TileScheduler tileScheduler, bool transposeMmaOutput,
95+
bool useCustomMmaSchedule, bool useDeepSeekFp8, bool useHoistTryWaitForCustomMmaSchedule,
96+
bool usePerTokenSfA, bool usePerTokenSfB, bool useShuffledMatrixA, bool useTmaStore,
97+
bool useTwoTmaLoadWarps, bool useTwoMmaWarps, bool useUnrollLoop2xForMma, int worldSize,
98+
gemmGatedAct::ActType actType, bool clampBeforeAct, std::vector<int> batchedM,
99+
std::vector<int> batchedN, BatchMode batchMode, int numBatches, bool isStaticBatch,
100+
int numTokens, RouteImpl routeImpl, std::optional<RouteImpl> routeSfsImpl,
101+
bool gridWaitForPrimaryRouting, bool fusedAct, bool useTmaOobOpt)
100102
: gemmGatedAct::GemmGatedActOptions(
101103
gemm::GemmOptions(
102-
allReduceAlgo, biasType, blockK, clusterDimX, clusterDimY, clusterDimZ, dtypeAcc,
103-
dtypeA, dtypeB, dtypeC, dtypeMmaA, dtypeMmaB, enablesEarlyExit,
104-
enablesDelayedEarlyExit, enablesGlobalPtxKnobs, epilogueLdtmDps, epilogueLdtmBits,
105-
epilogueTileM, epilogueTileN, gridTriggerSecondaryA, gridTriggerSecondaryB,
106-
gridWaitForPrimaryEarlyExit, gridWaitForPrimaryA, gridWaitForPrimaryB,
107-
hoistLoadTaskInit, hoistMmaTaskTryWaits, k, kernelTraits, layoutA, layoutB, m, mmaK,
108-
mmaKind, mmaM, mmaN, mockAllReduce, n, numSlicesForSplitK, numSlicesForSliceK,
109-
numStages, numStagesMma, numStagesMmaWithinWorkTile, numStagesMmaAcrossWorkTile,
110-
numStagesWorkId, outputDebugTensors, patchF2fp, sfBlockSizeA, sfLayoutA, sfLayoutB,
111-
sfLayoutC, sfReshapeFactor, sliceK, splitK, tileK, tileM, tileN, tileScheduler,
104+
allReduceAlgo, biasType, blockK, clusterDimX, clusterDimY, clusterDimZ,
105+
ctaSwizzleType, dtypeAcc, dtypeA, dtypeB, dtypeC, dtypeMmaA, dtypeMmaB,
106+
enablesEarlyExit, enablesDelayedEarlyExit, enablesGlobalPtxKnobs, epilogueLdtmDps,
107+
epilogueLdtmBits, epilogueTileM, epilogueTileN, gridTriggerSecondaryA,
108+
gridTriggerSecondaryB, gridWaitForPrimaryEarlyExit, gridWaitForPrimaryA,
109+
gridWaitForPrimaryB, hoistLoadTaskInit, hoistMmaTaskTryWaits, k, kernelTraits,
110+
layoutA, layoutB, m, mmaK, mmaKind, mmaM, mmaN, mockAllReduce, n, numRegsCastAWarps,
111+
numRegsCopySfLdsSttm, numRegsPerThreadEpilogueWarp, numRegsPerThreadNonEpilogueWarp,
112+
numSlicesForSplitK, numSlicesForSliceK, numStages, numStagesMma,
113+
numStagesMmaWithinWorkTile, numStagesMmaAcrossWorkTile, numStagesWorkId,
114+
outputDebugTensors, patchF2fp, sfBlockSizeA, sfLayoutA, sfLayoutB, sfLayoutC,
115+
sfReshapeFactor, sliceK, splitK, tileK, tileM, tileN, tileScheduler,
112116
transposeMmaOutput, useCustomMmaSchedule, useDeepSeekFp8,
113117
useHoistTryWaitForCustomMmaSchedule, usePerTokenSfA, usePerTokenSfB,
114118
useShuffledMatrixA, useTmaStore, useTwoTmaLoadWarps, useTwoMmaWarps,
@@ -121,11 +125,9 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions {
121125
mGridWaitForPrimaryRouting(gridWaitForPrimaryRouting),
122126
mIsStaticBatch(isStaticBatch),
123127
mNumBatches(numBatches),
124-
mNumRegsPerThreadNonEpilogueWarp(numRegsPerThreadNonEpilogueWarp),
125-
mNumRegsPerThreadEpilogueWarp(numRegsPerThreadEpilogueWarp),
126-
mNumRegsCastAWarps(numRegsCastAWarps),
127128
mNumTokens(numTokens),
128129
mRouteImpl(routeImpl),
130+
mRouteSfsImpl(routeSfsImpl),
129131
mUseTmaOobOpt(useTmaOobOpt) {}
130132

131133
// Batched M-dimensions of GEMM.
@@ -143,16 +145,12 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions {
143145
bool mIsStaticBatch{true};
144146
// Number of Gemm batches.
145147
int mNumBatches;
146-
// Number of registers per thread for non-epilogue warps
147-
int mNumRegsPerThreadNonEpilogueWarp{0};
148-
// Number of registers per thread for epilogue warps
149-
int mNumRegsPerThreadEpilogueWarp{0};
150-
// Number of registers for the cast A warps.
151-
int mNumRegsCastAWarps{0};
152148
// Total number of tokens.
153149
int mNumTokens{32};
154150
// Whether load the input tokens and do routing.
155151
RouteImpl mRouteImpl{RouteImpl::NoRoute};
152+
// Routing logic for scaling factors. If not specified, mRouteImpl is used.
153+
std::optional<RouteImpl> mRouteSfsImpl{std::nullopt};
156154
// Whether to use TMA out-of-bounds optimization to reduce wasted traffic. See details in
157155
// BatchedGemm/KernelParamsDecl.h.
158156
bool mUseTmaOobOpt{false};
@@ -235,6 +233,18 @@ bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, bool isBlackw
235233
"E2m1 is not supported with DeepSeek FP8");
236234
}
237235

236+
if (options.mRouteSfsImpl.has_value() && options.mRouteSfsImpl.value() != options.mRouteImpl) {
237+
TLLM_CHECK_ERROR(
238+
options.mRouteSfsImpl.value() == RouteImpl::Ldgsts && options.mRouteImpl == RouteImpl::Tma,
239+
"RouteSfsImpl must be equal to RouteImpl, or Ldgsts, when RouteImpl is Tma");
240+
} else if (!options.mRouteSfsImpl.has_value()) {
241+
if (updateOptions) {
242+
options.mRouteSfsImpl = options.mRouteImpl;
243+
} else {
244+
TLLM_LOG_ERROR("RouteSfsImpl must be specified");
245+
return false;
246+
}
247+
}
238248
if (batchM) {
239249
if (options.mDtypeA == tg::Dtype::MxE2m1 && options.mMmaKind == tg::MmaKind::MxFp8Fp6Fp4) {
240250
TLLM_CHECK_ERROR(doesRouteImplUseNoRoute(options.mRouteImpl),
@@ -269,18 +279,20 @@ bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, bool isBlackw
269279
}
270280
}
271281

272-
if (doesRouteImplUseTma(options.mRouteImpl)) {
282+
if (doesRouteImplUseTma(options.mRouteSfsImpl.value())) {
273283
TLLM_CHECK_ERROR(!batchM, "UTMALDG.GATHER4 only supported for batch N.");
274284

275285
if (tg::mmaKindIsBlockFmt(options.mMmaKind)) {
276286
auto dtypeRoute = batchM ? options.mDtypeA : options.mDtypeB;
277-
TLLM_CHECK_ERROR(options.mTileK % tg::dtypeNumEltsPerSf(dtypeRoute) == 0,
278-
"tileK needs to be a multiple of 16 * tg::dtypeNumEltsPerSf(dtypeA).");
279287
TLLM_CHECK_ERROR(options.mTileK % (tg::dtypeNumEltsPerSf(dtypeRoute) * 16) == 0,
280288
"tileK needs to be a multiple of 16 * tg::dtypeNumEltsPerSf(dtypeA).");
281289
}
282290
}
283291

292+
if (options.mClusterDimX > 1) {
293+
TLLM_CHECK_ERROR(!batchM, "2CTA Gemm currently only supports batch N.");
294+
}
295+
284296
if (!batchM || doesRouteImplUseNoRoute(options.mRouteImpl)) {
285297
TLLM_CHECK_ERROR(options.mSfLayoutA == tg::SfLayout::R128c4,
286298
"options.mSfLayoutA has to be tg::SfLayout::R128c4 when not being routed");
@@ -301,6 +313,11 @@ bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, bool isBlackw
301313
TLLM_CHECK_ERROR(options.mK % options.mTileK == 0, "K must be a multiple of TileK");
302314
}
303315

316+
if (options.mClusterDimX > 1 && batchM && options.mRouteImpl != RouteImpl::NoRoute) {
317+
TLLM_CHECK_ERROR(false,
318+
"2CTA BatchedGemm does not support routing along M dimension. To support it, "
319+
"change the input routing data layout to be padded to clusterDimX size.");
320+
}
304321
return isValid;
305322
}
306323

@@ -323,6 +340,7 @@ struct BatchedGemmConfig {
323340
char const* mHash{nullptr};
324341
#else
325342
trtllm::gen::CudaRunner* mCudaRunner{nullptr};
343+
int32_t mInstanceIdx{0};
326344
#endif
327345

328346
BatchedGemmOptions mOptions;
@@ -343,13 +361,10 @@ inline std::string dumpOptions(BatchedGemmOptions const& options) {
343361
ss << "mNumTokens=" << options.mNumTokens << "," << std::endl;
344362
ss << "mRouteImpl=batchedGemm::RouteImpl(" << static_cast<int32_t>(options.mRouteImpl) << "),"
345363
<< std::endl;
364+
ss << "mRouteSfsImpl={batchedGemm::RouteImpl("
365+
<< static_cast<int32_t>(options.mRouteSfsImpl.value()) << ")}," << std::endl;
346366
ss << "mGridWaitForPrimaryRouting=" << options.mGridWaitForPrimaryRouting << "," << std::endl;
347367
ss << "mFusedAct=" << options.mFusedAct << "," << std::endl;
348-
ss << "mNumRegsPerThreadNonEpilogueWarp=" << options.mNumRegsPerThreadNonEpilogueWarp << ","
349-
<< std::endl;
350-
ss << "mNumRegsPerThreadEpilogueWarp=" << options.mNumRegsPerThreadEpilogueWarp << ","
351-
<< std::endl;
352-
ss << "mNumRegsCastAWarps=" << options.mNumRegsCastAWarps << "," << std::endl;
353368
ss << "mUseTmaOobOpt=" << options.mUseTmaOobOpt << std::endl;
354369
return ss.str();
355370
}

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/Enums.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,23 @@ enum class TileScheduler {
9797

9898
////////////////////////////////////////////////////////////////////////////////////////////////////
9999

100+
enum class CtaSwizzleType : uint32_t {
101+
// Rasterize CTAs along the M dimension.
102+
RasterizeAlongM = 0,
103+
// Rasterize CTAs along the N dimension.
104+
RasterizeAlongN,
105+
// Swizzle CTAs in zig-zag pattern along M dimension, Zig-zag width is 2.
106+
ZigZagAlongM2,
107+
// Swizzle CTAs in zig-zag pattern along N dimension, Zig-zag width is 2.
108+
ZigZagAlongN2,
109+
// Swizzle CTAs in zig-zag pattern along M dimension, Zig-zag width is 4.
110+
ZigZagAlongM4,
111+
// Swizzle CTAs in zig-zag pattern along N dimension, Zig-zag width is 4.
112+
ZigZagAlongN4,
113+
};
114+
115+
////////////////////////////////////////////////////////////////////////////////////////////////////
116+
100117
// Helper functions to check the SplitK type.
101118

102119
#define SPLIT_K_FUNCTION(Mode) \

0 commit comments

Comments (0)