@@ -86,34 +86,36 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions
8686
8787 // FIXME: We create an explicit constructor with all options as a workaround (WAR) for a stubgen issue in TRT-LLM.
8888 BatchedGemmOptions (gemm::AllReduceAlgo allReduceAlgo, gemm::BiasType biasType, int blockK, int clusterDimX,
89- int clusterDimY, int clusterDimZ, tg::Dtype dtypeAcc, tg::Dtype dtypeA, tg::Dtype dtypeB, tg::Dtype dtypeC,
90- tg::Dtype dtypeMmaA, tg::Dtype dtypeMmaB, bool enablesEarlyExit, bool enablesDelayedEarlyExit,
91- bool enablesGlobalPtxKnobs, int epilogueLdtmDps, int epilogueLdtmBits, int epilogueTileM, int epilogueTileN,
92- bool gridTriggerSecondaryA, bool gridTriggerSecondaryB, bool gridWaitForPrimaryEarlyExit,
93- bool gridWaitForPrimaryA, bool gridWaitForPrimaryB, bool hoistLoadTaskInit, bool hoistMmaTaskTryWaits, int k,
94- gemm::KernelTraits kernelTraits, gemm::MatrixLayout layoutA, gemm::MatrixLayout layoutB, int m, int mmaK,
95- tg::MmaKind mmaKind, int mmaM, int mmaN, bool mockAllReduce, int n, int numSlicesForSplitK,
96- int numSlicesForSliceK, int numStages, int numStagesMma, int numStagesMmaWithinWorkTile,
97- int numStagesMmaAcrossWorkTile, int numStagesWorkId, bool outputDebugTensors, bool patchF2fp,
98- std::optional<int32_t > sfBlockSizeA, tg::SfLayout sfLayoutA, tg::SfLayout sfLayoutB, tg::SfLayout sfLayoutC,
99- int32_t sfReshapeFactor, bool sliceK, gemm::SplitK splitK, int tileK, int tileM, int tileN,
100- gemm::TileScheduler tileScheduler, bool transposeMmaOutput, bool useCustomMmaSchedule, bool useDeepSeekFp8,
101- bool useHoistTryWaitForCustomMmaSchedule, bool usePerTokenSfA, bool usePerTokenSfB, bool useShuffledMatrixA,
102- bool useTmaStore, bool useTwoTmaLoadWarps, bool useTwoMmaWarps, bool useUnrollLoop2xForMma, int worldSize,
103- gemmGatedAct::ActType actType, bool clampBeforeAct, std::vector<int > batchedM, std::vector<int > batchedN,
104- BatchMode batchMode, int numBatches, bool isStaticBatch, int numTokens, RouteImpl routeImpl,
105- bool gridWaitForPrimaryRouting, bool fusedAct, int numRegsPerThreadNonEpilogueWarp,
106- int numRegsPerThreadEpilogueWarp, int numRegsCastAWarps, bool useTmaOobOpt)
89+ int clusterDimY, int clusterDimZ, gemm::CtaSwizzleType ctaSwizzleType, tg::Dtype dtypeAcc, tg::Dtype dtypeA,
90+ tg::Dtype dtypeB, tg::Dtype dtypeC, tg::Dtype dtypeMmaA, tg::Dtype dtypeMmaB, bool enablesEarlyExit,
91+ bool enablesDelayedEarlyExit, bool enablesGlobalPtxKnobs, int epilogueLdtmDps, int epilogueLdtmBits,
92+ int epilogueTileM, int epilogueTileN, bool gridTriggerSecondaryA, bool gridTriggerSecondaryB,
93+ bool gridWaitForPrimaryEarlyExit, bool gridWaitForPrimaryA, bool gridWaitForPrimaryB, bool hoistLoadTaskInit,
94+ bool hoistMmaTaskTryWaits, int k, gemm::KernelTraits kernelTraits, gemm::MatrixLayout layoutA,
95+ gemm::MatrixLayout layoutB, int m, int mmaK, tg::MmaKind mmaKind, int mmaM, int mmaN, bool mockAllReduce, int n,
96+ int numRegsCastAWarps, int numRegsCopySfLdsSttm, int numRegsPerThreadEpilogueWarp,
97+ int numRegsPerThreadNonEpilogueWarp, int numSlicesForSplitK, int numSlicesForSliceK, int numStages,
98+ int numStagesMma, int numStagesMmaWithinWorkTile, int numStagesMmaAcrossWorkTile, int numStagesWorkId,
99+ bool outputDebugTensors, bool patchF2fp, std::optional<int32_t > sfBlockSizeA, tg::SfLayout sfLayoutA,
100+ tg::SfLayout sfLayoutB, tg::SfLayout sfLayoutC, int32_t sfReshapeFactor, bool sliceK, gemm::SplitK splitK,
101+ int tileK, int tileM, int tileN, gemm::TileScheduler tileScheduler, bool transposeMmaOutput,
102+ bool useCustomMmaSchedule, bool useDeepSeekFp8, bool useHoistTryWaitForCustomMmaSchedule, bool usePerTokenSfA,
103+ bool usePerTokenSfB, bool useShuffledMatrixA, bool useTmaStore, bool useTwoTmaLoadWarps, bool useTwoMmaWarps,
104+ bool useUnrollLoop2xForMma, int worldSize, gemmGatedAct::ActType actType, bool clampBeforeAct,
105+ std::vector<int > batchedM, std::vector<int > batchedN, BatchMode batchMode, int numBatches, bool isStaticBatch,
106+ int numTokens, RouteImpl routeImpl, std::optional<RouteImpl> routeSfsImpl, bool gridWaitForPrimaryRouting,
107+ bool fusedAct, bool useTmaOobOpt)
107108 : gemmGatedAct::GemmGatedActOptions(
108- gemm::GemmOptions (allReduceAlgo, biasType, blockK, clusterDimX, clusterDimY, clusterDimZ, dtypeAcc, dtypeA,
109- dtypeB, dtypeC, dtypeMmaA, dtypeMmaB, enablesEarlyExit, enablesDelayedEarlyExit, enablesGlobalPtxKnobs,
110- epilogueLdtmDps, epilogueLdtmBits, epilogueTileM, epilogueTileN, gridTriggerSecondaryA,
111- gridTriggerSecondaryB, gridWaitForPrimaryEarlyExit, gridWaitForPrimaryA, gridWaitForPrimaryB,
112- hoistLoadTaskInit, hoistMmaTaskTryWaits, k, kernelTraits, layoutA, layoutB, m, mmaK, mmaKind, mmaM,
113- mmaN, mockAllReduce, n, numSlicesForSplitK, numSlicesForSliceK, numStages, numStagesMma,
114- numStagesMmaWithinWorkTile, numStagesMmaAcrossWorkTile, numStagesWorkId, outputDebugTensors, patchF2fp,
115- sfBlockSizeA, sfLayoutA, sfLayoutB, sfLayoutC, sfReshapeFactor, sliceK, splitK, tileK, tileM, tileN,
116- tileScheduler, transposeMmaOutput, useCustomMmaSchedule, useDeepSeekFp8,
109+ gemm::GemmOptions (allReduceAlgo, biasType, blockK, clusterDimX, clusterDimY, clusterDimZ, ctaSwizzleType,
110+ dtypeAcc, dtypeA, dtypeB, dtypeC, dtypeMmaA, dtypeMmaB, enablesEarlyExit, enablesDelayedEarlyExit,
111+ enablesGlobalPtxKnobs, epilogueLdtmDps, epilogueLdtmBits, epilogueTileM, epilogueTileN,
112+ gridTriggerSecondaryA, gridTriggerSecondaryB, gridWaitForPrimaryEarlyExit, gridWaitForPrimaryA,
113+ gridWaitForPrimaryB, hoistLoadTaskInit, hoistMmaTaskTryWaits, k, kernelTraits, layoutA, layoutB, m,
114+ mmaK, mmaKind, mmaM, mmaN, mockAllReduce, n, numRegsCastAWarps, numRegsCopySfLdsSttm,
115+ numRegsPerThreadEpilogueWarp, numRegsPerThreadNonEpilogueWarp, numSlicesForSplitK, numSlicesForSliceK,
116+ numStages, numStagesMma, numStagesMmaWithinWorkTile, numStagesMmaAcrossWorkTile, numStagesWorkId,
117+ outputDebugTensors, patchF2fp, sfBlockSizeA, sfLayoutA, sfLayoutB, sfLayoutC, sfReshapeFactor, sliceK,
118+ splitK, tileK, tileM, tileN, tileScheduler, transposeMmaOutput, useCustomMmaSchedule, useDeepSeekFp8,
117119 useHoistTryWaitForCustomMmaSchedule, usePerTokenSfA, usePerTokenSfB, useShuffledMatrixA, useTmaStore,
118120 useTwoTmaLoadWarps, useTwoMmaWarps, useUnrollLoop2xForMma, worldSize),
119121 actType, clampBeforeAct)
@@ -124,11 +126,9 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions
124126 , mGridWaitForPrimaryRouting(gridWaitForPrimaryRouting)
125127 , mIsStaticBatch(isStaticBatch)
126128 , mNumBatches(numBatches)
127- , mNumRegsPerThreadNonEpilogueWarp(numRegsPerThreadNonEpilogueWarp)
128- , mNumRegsPerThreadEpilogueWarp(numRegsPerThreadEpilogueWarp)
129- , mNumRegsCastAWarps(numRegsCastAWarps)
130129 , mNumTokens(numTokens)
131130 , mRouteImpl(routeImpl)
131+ , mRouteSfsImpl(routeSfsImpl)
132132 , mUseTmaOobOpt(useTmaOobOpt)
133133 {
134134 }
@@ -148,16 +148,12 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions
148148 bool mIsStaticBatch {true };
149149 // Number of Gemm batches.
150150 int mNumBatches ;
151- // Number of registers per thread for non-epilogue warps
152- int mNumRegsPerThreadNonEpilogueWarp {0 };
153- // Number of registers per thread for epilogue warps
154- int mNumRegsPerThreadEpilogueWarp {0 };
155- // Number of registers for the cast A warps.
156- int mNumRegsCastAWarps {0 };
157151 // Total number of tokens.
158152 int mNumTokens {32 };
159153 // Whether to load the input tokens and do routing.
160154 RouteImpl mRouteImpl {RouteImpl::NoRoute};
155+ // Routing logic for scaling factors. If not specified, mRouteImpl is used.
156+ std::optional<RouteImpl> mRouteSfsImpl {std::nullopt };
161157 // Whether to use TMA out-of-bounds optimization to reduce wasted traffic. See details in
162158 // BatchedGemm/KernelParamsDecl.h.
163159 bool mUseTmaOobOpt {false };
@@ -255,6 +251,24 @@ bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, bool isBlackw
255251 " E2m1 is not supported with DeepSeek FP8" );
256252 }
257253
254+ if (options.mRouteSfsImpl .has_value () && options.mRouteSfsImpl .value () != options.mRouteImpl )
255+ {
256+ TLLM_CHECK_ERROR (options.mRouteSfsImpl .value () == RouteImpl::Ldgsts && options.mRouteImpl == RouteImpl::Tma,
257+ " RouteSfsImpl must be equal to RouteImpl, or Ldgsts, when RouteImpl is Tma" );
258+ }
259+ else if (!options.mRouteSfsImpl .has_value ())
260+ {
261+ if (updateOptions)
262+ {
263+ options.mRouteSfsImpl = options.mRouteImpl ;
264+ }
265+ else
266+ {
267+ TLLM_LOG_ERROR (" RouteSfsImpl must be specified" );
268+ return false ;
269+ }
270+ }
271+
258272 if (batchM)
259273 {
260274 if (options.mDtypeA == tg::Dtype::MxE2m1 && options.mMmaKind == tg::MmaKind::MxFp8Fp6Fp4)
@@ -299,20 +313,23 @@ bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, bool isBlackw
299313 }
300314 }
301315
302- if (doesRouteImplUseTma (options.mRouteImpl ))
316+ if (doesRouteImplUseTma (options.mRouteSfsImpl . value () ))
303317 {
304318 TLLM_CHECK_ERROR (!batchM, " UTMALDG.GATHER4 only supported for batch N." );
305319
306320 if (tg::mmaKindIsBlockFmt (options.mMmaKind ))
307321 {
308322 auto dtypeRoute = batchM ? options.mDtypeA : options.mDtypeB ;
309- TLLM_CHECK_ERROR (options.mTileK % tg::dtypeNumEltsPerSf (dtypeRoute) == 0 ,
310- " tileK needs to be a multiple of 16 * tg::dtypeNumEltsPerSf(dtypeA)." );
311323 TLLM_CHECK_ERROR (options.mTileK % (tg::dtypeNumEltsPerSf (dtypeRoute) * 16 ) == 0 ,
312324 " tileK needs to be a multiple of 16 * tg::dtypeNumEltsPerSf(dtypeA)." );
313325 }
314326 }
315327
328+ if (options.mClusterDimX > 1 )
329+ {
330+ TLLM_CHECK_ERROR (!batchM, " 2CTA Gemm currently only supports batch N." );
331+ }
332+
316333 if (!batchM || doesRouteImplUseNoRoute (options.mRouteImpl ))
317334 {
318335 TLLM_CHECK_ERROR (options.mSfLayoutA == tg::SfLayout::R128c4,
@@ -336,6 +353,13 @@ bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, bool isBlackw
336353 TLLM_CHECK_ERROR (options.mK % options.mTileK == 0 , " K must be a multiple of TileK" );
337354 }
338355
356+ if (options.mClusterDimX > 1 && batchM && options.mRouteImpl != RouteImpl::NoRoute)
357+ {
358+ TLLM_CHECK_ERROR (false ,
359+ " 2CTA BatchedGemm does not support routing along M dimension. To support it, "
360+ " change the input routing data layout to be padded to clusterDimX size." );
361+ }
362+
339363 return isValid;
340364}
341365
@@ -359,6 +383,7 @@ struct BatchedGemmConfig
359383 char const * mHash {nullptr };
360384#else
361385 trtllm::gen::CudaRunner* mCudaRunner {nullptr };
386+ int32_t mInstanceIdx {0 };
362387#endif
363388
364389 BatchedGemmOptions mOptions ;
@@ -379,11 +404,10 @@ inline std::string dumpOptions(BatchedGemmOptions const& options)
379404 ss << " mIsStaticBatch=" << options.mIsStaticBatch << " ," << std::endl;
380405 ss << " mNumTokens=" << options.mNumTokens << " ," << std::endl;
381406 ss << " mRouteImpl=batchedGemm::RouteImpl(" << static_cast <int32_t >(options.mRouteImpl ) << " )," << std::endl;
407+ ss << " mRouteSfsImpl={batchedGemm::RouteImpl(" << static_cast <int32_t >(options.mRouteSfsImpl .value ()) << " )},"
408+ << std::endl;
382409 ss << " mGridWaitForPrimaryRouting=" << options.mGridWaitForPrimaryRouting << " ," << std::endl;
383410 ss << " mFusedAct=" << options.mFusedAct << " ," << std::endl;
384- ss << " mNumRegsPerThreadNonEpilogueWarp=" << options.mNumRegsPerThreadNonEpilogueWarp << " ," << std::endl;
385- ss << " mNumRegsPerThreadEpilogueWarp=" << options.mNumRegsPerThreadEpilogueWarp << " ," << std::endl;
386- ss << " mNumRegsCastAWarps=" << options.mNumRegsCastAWarps << " ," << std::endl;
387411 ss << " mUseTmaOobOpt=" << options.mUseTmaOobOpt << std::endl;
388412 return ss.str ();
389413}
0 commit comments