Skip to content

Commit 7c6df0e

Browse files
authored
[None][feat] fuse shared to sparse experts in TRT-LLM Gen MoE (#11143)
Signed-off-by: Nikita Korobov <14355239+nekorobov@users.noreply.github.com>
1 parent be20949 commit 7c6df0e

File tree

14 files changed

+458
-134
lines changed

14 files changed

+458
-134
lines changed

cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingDeepSeek.cu

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -269,17 +269,31 @@ __global__ void routingMainKernel(KernelParams params)
269269
auto finalScore = OutputT{scoreNorm * params.mRouteScale / redNorm};
270270

271271
// write expert idx out already
272-
auto idxTopK = blockIdx.x * params.mTopK + laneIdx;
272+
auto idxTopK = blockIdx.x * params.mTotalExpertsPerToken + laneIdx;
273+
auto idxShared = blockIdx.x * params.mTotalExpertsPerToken + params.mTopK + laneIdx;
273274
if (laneIdx < params.mTopK && params.mPtrTopKPacked != nullptr)
274275
{
275276
PackedScoreIdx<OutputT> packedScore{static_cast<OutputT>(finalScore), static_cast<int16_t>(expertIdx)};
276277
params.mPtrTopKPacked[idxTopK] = packedScore;
277278
}
278279

280+
if (laneIdx < params.mNumFusedSharedExperts && params.mPtrTopKPacked != nullptr)
281+
{
282+
PackedScoreIdx<OutputT> packedScore{
283+
static_cast<OutputT>(1.0F), static_cast<int16_t>(params.mNumExperts + laneIdx)};
284+
params.mPtrTopKPacked[idxShared] = packedScore;
285+
}
286+
279287
if (laneIdx < params.mTopK && params.mPtrTopKWeights != nullptr && params.mPtrTopKIds == nullptr)
280288
{
281289
params.mPtrTopKWeights[idxTopK] = finalScore;
282290
}
291+
292+
// Write score of 1.0 for shared expert if enabled
293+
if (laneIdx < params.mNumFusedSharedExperts && params.mPtrTopKWeights != nullptr)
294+
{
295+
params.mPtrTopKWeights[idxShared] = static_cast<OutputT>(1.0F);
296+
}
283297
}
284298
}
285299
}
@@ -629,9 +643,15 @@ void run(Data& data, void* stream)
629643
"If permuted index is required, `mPtrTopKPacked` or `mPtrTopKIds` is also required");
630644
TLLM_CHECK_WITH_INFO(!data.mUseRoutingSoftmax, "Routing with softmax not implemented yet");
631645
int const numBlocks = data.mNumTokens;
632-
int const numThreadsHist = getMaxNumExperts(data.mNumExperts);
646+
int const numExperts = data.mNumExperts + data.mNumFusedSharedExperts;
647+
int const topK = data.mTopK + data.mNumFusedSharedExperts;
648+
int const numThreadsHist = getMaxNumExperts(numExperts);
649+
int const maxNumTopExperts = getMaxNumExperts(numExperts);
650+
651+
// Number of threads in the cluster.
652+
int numThreadsPerCluster = numThreadsHist * NumBlocksPerCluster;
633653

634-
bool const useSingleCluster = data.mNumTokens <= 1024;
654+
bool const useSingleCluster = data.mNumTokens <= 1024 && data.mNumTokens * topK <= numThreadsPerCluster;
635655
if (!useSingleCluster)
636656
{
637657
// Reset the global histograms (not used in single-cluster code path).
@@ -658,15 +678,15 @@ void run(Data& data, void* stream)
658678
int const numBlocksCoop = smCount - 8;
659679

660680
// Maximum number of tokens supported by the kernel using a cooperative launch.
661-
int const maxTokensCoop = (numBlocksCoop * numThreadsHist * 64) / data.mTopK;
681+
int const maxTokensCoop = (numBlocksCoop * numThreadsHist * 64) / topK;
662682
if (data.mPtrTopKIds == nullptr)
663683
{
664684
TLLM_CHECK_WITH_INFO(data.mNumExperts >= MaxSupportedTopExperts,
665685
"Routing kernel expects %d to be at most #experts %d", MaxSupportedTopExperts, data.mNumExperts);
666686
TLLM_CHECK_WITH_INFO(data.mNumExperts <= MaxSupportedExpertCount,
667687
"Routing kernel expects #experts %d <= #threads %d", data.mNumExperts, MaxSupportedExpertCount);
668-
TLLM_CHECK_WITH_INFO(data.mTopK <= MaxSupportedTopExperts, "Routing kernel expects topK experts <= %d, got %d",
669-
MaxSupportedTopExperts, data.mTopK);
688+
TLLM_CHECK_WITH_INFO(topK <= MaxSupportedTopExperts, "Routing kernel expects topK experts <= %d, got %d",
689+
MaxSupportedTopExperts, topK);
670690

671691
// Routing needs to be executed - validate routing kernel constraints
672692
if (data.mNumExpertGroups > 1)
@@ -690,6 +710,16 @@ void run(Data& data, void* stream)
690710
data.mNumExpertGroups);
691711
TLLM_CHECK_WITH_INFO(data.mNumExperts % 4 == 0, "Routing kernel expects #experts %d to be a multiple of 4.",
692712
data.mNumExperts);
713+
714+
TLLM_CHECK_WITH_INFO(data.mNumFusedSharedExperts <= WarpSize,
715+
"Number of fused shared experts (%d) must be at most warp size %d.", data.mNumFusedSharedExperts, WarpSize);
716+
717+
if (data.mNumFusedSharedExperts > 0)
718+
{
719+
// Disabling due to lack of testing
720+
// TLLM_CHECK_WITH_INFO(
721+
// data.mPtrTopKPacked == nullptr, "Shared expert fusion is not compatible with packed scores");
722+
}
693723
}
694724

695725
int const numThreadsMain = max(data.mNumExpertGroups * WarpSize, getMaxNumExperts(data.mNumExperts));
@@ -707,6 +737,14 @@ void run(Data& data, void* stream)
707737
stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/false);
708738
}
709739

740+
if (data.mNumFusedSharedExperts > 0)
741+
{
742+
data.mNumExperts += data.mNumFusedSharedExperts;
743+
data.mTopK += data.mNumFusedSharedExperts;
744+
data.mNumLocalExperts += data.mNumFusedSharedExperts;
745+
// data.mLocalExpertsStartIdx += data.mNumFusedSharedExperts;
746+
}
747+
710748
if (data.mPtrPermutedIdxSize != nullptr)
711749
{
712750
if (useSingleCluster)
@@ -725,7 +763,7 @@ void run(Data& data, void* stream)
725763
}
726764
else
727765
{
728-
const int32_t expandedIdxSize = data.mNumTokens * data.mTopK;
766+
const int32_t expandedIdxSize = data.mNumTokens * topK;
729767
const int32_t histogramEltsPerBlock = 8 * numThreadsHist;
730768
const int32_t offsetEltsPerBlock = NumEltsPerOffsetTilePerThread * numThreadsHist;
731769

cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,12 @@ struct DataBase
107107
int32_t mLocalExpertsStartIdx;
108108
int32_t mLocalExpertsStrideLog2;
109109
int32_t mNumLocalExperts;
110+
111+
/// For fused shared expert
112+
int32_t mNumFusedSharedExperts;
113+
int32_t mSharedExpertTokenOffset;
114+
int32_t mSharedExpertNumTokens;
115+
int32_t mTotalExpertsPerToken;
110116
};
111117

112118
template <typename InputT_, typename OutputT_, int MaxNumExperts_, bool isPow2_, bool UsePdl_>
@@ -141,6 +147,11 @@ struct KernelParamsBase
141147
int32_t mLocalExpertsStrideLog2 = 0;
142148
int32_t mNumLocalExperts = 0;
143149

150+
int32_t mNumFusedSharedExperts;
151+
int32_t mSharedExpertTokenOffset;
152+
int32_t mSharedExpertNumTokens;
153+
int32_t mTotalExpertsPerToken;
154+
144155
// Public initialization function - make it a template to accept different Data types
145156
template <typename DataType>
146157
void setBaseParams(DataType const& data)
@@ -165,6 +176,11 @@ struct KernelParamsBase
165176
mLocalExpertsStartIdx = data.mLocalExpertsStartIdx;
166177
mLocalExpertsStrideLog2 = data.mLocalExpertsStrideLog2;
167178
mNumLocalExperts = data.mNumLocalExperts;
179+
180+
mNumFusedSharedExperts = data.mNumFusedSharedExperts;
181+
mSharedExpertTokenOffset = data.mSharedExpertTokenOffset;
182+
mSharedExpertNumTokens = data.mSharedExpertNumTokens;
183+
mTotalExpertsPerToken = data.mTotalExpertsPerToken;
168184
}
169185
};
170186

cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.cu

Lines changed: 50 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@ Runner::Runner(int32_t tileTokensDim)
6161
}
6262

6363
void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int32_t numExperts, int32_t topK,
64-
int32_t nGroup, int32_t topkGroup, int32_t localExpertOffset, int32_t localNumExperts, float routedScalingFactor,
65-
int32_t* routingExpertIndexes, int32_t* expertCountHistogram, int32_t* permutedIdxSize,
64+
int32_t numFusedSharedExpert, int32_t nGroup, int32_t topkGroup, int32_t localExpertOffset, int32_t localNumExperts,
65+
float routedScalingFactor, int32_t* routingExpertIndexes, int32_t* expertCountHistogram, int32_t* permutedIdxSize,
6666
int32_t* expandedIdxToPermutedIdx, int32_t* permutedIdxToExpandedIdx, int32_t* permutedIdxToTokenIdx,
6767
void* expertWeights, int32_t* expertIds, int32_t* numTokensPerExpert, int32_t* ctaIdxXyToBatchIdx,
6868
int32_t* ctaIdxXyToMnLimit, int32_t* numNonExitingCtas, btg::Dtype dtypeElt, bool useRoutingScalesOnInput,
@@ -76,6 +76,8 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
7676
routingData.mDtypeExpW = btg::Dtype::Bfloat16;
7777
routingData.mUsePdl = true;
7878

79+
int32_t const totalExpertsPerToken = topK + numFusedSharedExpert;
80+
7981
// output:
8082
routingData.mPtrTopKPacked = routingExpertIndexes;
8183
routingData.mPtrExpertCounts = expertCountHistogram;
@@ -96,16 +98,35 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
9698
routingData.mPtrTopKIds = expertIds;
9799
routingData.mNumTokens = numTokens;
98100
routingData.mNumExperts = numExperts;
101+
routingData.mNumFusedSharedExperts = numFusedSharedExpert;
99102
routingData.mNumExpertGroups = nGroup;
100103
routingData.mNumLimitedGroups = topkGroup;
101104
routingData.mTopK = topK;
105+
routingData.mTotalExpertsPerToken = totalExpertsPerToken;
102106
routingData.mPaddingLog2 = computeLog2(mTileTokensDim);
103107
routingData.mTileTokensDim = mTileTokensDim;
104108
routingData.mLocalExpertsStartIdx = localExpertOffset;
105109
routingData.mLocalExpertsStrideLog2 = 0;
106110
routingData.mNumLocalExperts = localNumExperts;
107111
routingData.mRouteScale = routedScalingFactor;
108112
routingData.mUseRoutingSoftmax = false;
113+
114+
// TODO Should these be passed directly instead? This does assume a constant number of experts per device
115+
int32_t const numDevices = numExperts / localNumExperts;
116+
int32_t const deviceIndex = localExpertOffset / localNumExperts;
117+
int32_t const baseTokensPerDevice = numTokens / numDevices;
118+
int32_t const remainingTokens = numTokens % numDevices;
119+
120+
if (deviceIndex < remainingTokens)
121+
{
122+
routingData.mSharedExpertTokenOffset = (baseTokensPerDevice + 1) * deviceIndex;
123+
routingData.mSharedExpertNumTokens = baseTokensPerDevice + 1;
124+
}
125+
else
126+
{
127+
routingData.mSharedExpertTokenOffset = remainingTokens + deviceIndex * baseTokensPerDevice;
128+
routingData.mSharedExpertNumTokens = baseTokensPerDevice;
129+
}
109130
moe::dev::routing::routingDeepSeek::run(routingData, stream);
110131
}
111132
else if (routingMethodType == RoutingMethodType::Llama4)
@@ -115,6 +136,8 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
115136
{
116137
TLLM_LOG_WARNING("For Llama routing method, nGroup/topkGroup is ignored, got %d/%d.", nGroup, topkGroup);
117138
}
139+
TLLM_CHECK_WITH_INFO(numFusedSharedExpert == 0, "Llama routing method does not support fusing shared expert");
140+
118141
moe::dev::routing::routingLlama4::Data routingData;
119142
routingData.mDtypeExpW = btg::Dtype::Bfloat16;
120143
routingData.mUsePdl = true;
@@ -159,6 +182,9 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
159182
else if (routingMethodType == RoutingMethodType::Renormalize /* default */
160183
|| routingMethodType == RoutingMethodType::RenormalizeNaive /* Softmax -> TopK */)
161184
{
185+
TLLM_CHECK_WITH_INFO(
186+
numFusedSharedExpert == 0, "Renormalize routing method does not support fusing shared expert");
187+
162188
moe::dev::routing::routingRenormalize::Data routingData;
163189

164190
//
@@ -434,6 +460,9 @@ void Runner::setOpsData(MoERunnerArgs const& args, MoEWorkspace const& workspace
434460
moe::dev::convertsf::Data& convertSfData, moe::dev::activation::Data& activationData,
435461
moe::dev::finalize::Data& finalizeData)
436462
{
463+
int32_t const totalNumExperts = args.num_experts + args.num_fused_shared_experts;
464+
int32_t const totalExpertsPerToken = args.top_k + args.num_fused_shared_experts;
465+
437466
// Setup sf conversion data if needed
438467
convertSfData.inSfPtr = args.hidden_states_scale;
439468
convertSfData.outSfPtr = workspace.hidden_states_scale_linear;
@@ -452,7 +481,7 @@ void Runner::setOpsData(MoERunnerArgs const& args, MoEWorkspace const& workspace
452481
activationData.inDqSfsPtr = workspace.gemm1_output_scale;
453482
activationData.outDqSfsPtr = workspace.activation_output_scale;
454483
activationData.innerDim = args.intermediate_size * 2;
455-
activationData.topK = args.top_k;
484+
activationData.topK = totalExpertsPerToken; // TODO Rename topK in activation data struct
456485
activationData.numTokens = args.num_tokens;
457486
activationData.expandedIdxToPermutedIdx = workspace.expanded_idx_to_permuted_idx;
458487

@@ -479,8 +508,8 @@ void Runner::setOpsData(MoERunnerArgs const& args, MoEWorkspace const& workspace
479508
}
480509
finalizeData.expandedIdxToPermutedIdx = workspace.expanded_idx_to_permuted_idx;
481510
finalizeData.numTokens = args.num_tokens;
482-
finalizeData.numExperts = args.num_experts;
483-
finalizeData.topK = args.top_k;
511+
finalizeData.numExperts = totalNumExperts; // TODO Is this used?
512+
finalizeData.topK = totalExpertsPerToken; // TODO Rename topK in finalize data struct
484513
// We want to fuse unpadding into the finalize kernel, so we need to use the output hidden size.
485514
finalizeData.hiddenDim = args.valid_hidden_size.value_or(args.hidden_size);
486515
finalizeData.hiddenDimPadded = args.output_hidden_size.value_or(args.hidden_size);
@@ -490,12 +519,15 @@ void Runner::setOpsData(MoERunnerArgs const& args, MoEWorkspace const& workspace
490519

491520
std::tuple<int32_t, int32_t> Runner::getWorkspaceSizeInBytes(MoERunnerArgs const& args, int64_t configIndex) const
492521
{
522+
int32_t const totalLocalExperts = args.local_num_experts + args.num_fused_shared_experts;
523+
int32_t const totalExpertsPerToken = args.top_k + args.num_fused_shared_experts;
524+
493525
auto const& config = mPassingConfigs[configIndex];
494526

495-
auto workspace_size_fc1 = static_cast<int32_t>(mPermuteGemm1.getWorkspaceSizeInBytes(args.top_k, args.hidden_size,
496-
args.intermediate_size, args.local_num_experts, args.num_tokens, config.gemm1Config));
497-
auto workspace_size_fc2 = static_cast<int32_t>(mGemm2.getWorkspaceSizeInBytes(args.top_k, args.hidden_size,
498-
args.intermediate_size, args.local_num_experts, args.num_tokens, config.gemm2Config));
527+
auto workspace_size_fc1 = static_cast<int32_t>(mPermuteGemm1.getWorkspaceSizeInBytes(totalExpertsPerToken,
528+
args.hidden_size, args.intermediate_size, totalLocalExperts, args.num_tokens, config.gemm1Config));
529+
auto workspace_size_fc2 = static_cast<int32_t>(mGemm2.getWorkspaceSizeInBytes(totalExpertsPerToken,
530+
args.hidden_size, args.intermediate_size, totalLocalExperts, args.num_tokens, config.gemm2Config));
499531
return std::make_tuple(workspace_size_fc1, workspace_size_fc2);
500532
}
501533

@@ -530,7 +562,6 @@ std::vector<int64_t> Runner::getValidConfigIndices(int32_t topK, int32_t hiddenS
530562
int64_t Runner::getDefaultValidConfigIndex(int32_t topK, int32_t hiddenSize, int32_t intermediateSize,
531563
int32_t numLocalExperts, int32_t numTokens, int32_t validHiddenSize, int32_t validIntermediateSize) const
532564
{
533-
534565
int32_t indexGemm1 = mPermuteGemm1.getDefaultValidConfigIndex(
535566
topK, hiddenSize, intermediateSize, numLocalExperts, numTokens, validHiddenSize, validIntermediateSize);
536567
int32_t indexGemm2 = mGemm2.getDefaultValidConfigIndex(
@@ -553,14 +584,17 @@ void Runner::run(
553584
sync_check_cuda_error(stream);
554585
setOpsData(args, workspace, convertSfData, activationData, finalizeData);
555586

587+
int32_t const totalLocalExperts = args.local_num_experts + args.num_fused_shared_experts;
588+
int32_t const totalExpertsPerToken = args.top_k + args.num_fused_shared_experts;
589+
556590
void* hidden_states_scale_linear{args.hidden_states_scale};
557591

558592
auto const& config = mPassingConfigs[configIndex];
559593

560594
mPermuteGemm1.run(args.hidden_states, hidden_states_scale_linear, args.gemm1_weights, args.gemm1_weights_scale,
561595
workspace.expert_weights, args.output1_scales_scalar, args.output1_scales_gate_scalar, args.gemm1_bias,
562596
args.gemm1_alpha, args.gemm1_beta, args.gemm1_clamp_limit, workspace.gemm1_output, workspace.gemm1_output_scale,
563-
args.top_k, args.hidden_size, args.intermediate_size, args.local_num_experts, args.num_tokens,
597+
totalExpertsPerToken, args.hidden_size, args.intermediate_size, totalLocalExperts, args.num_tokens,
564598
workspace.permuted_idx_to_token_idx, workspace.num_non_exiting_ctas, workspace.total_num_padded_tokens,
565599
workspace.cta_idx_xy_to_batch_idx, workspace.cta_idx_xy_to_mn_limit, workspace.bmm1_workspace,
566600
args.mUseRoutingScalesOnInput, device, stream, config.gemm1Config,
@@ -581,11 +615,11 @@ void Runner::run(
581615

582616
// Run gemm2
583617
mGemm2.run(gemm2_input, gemm2_input_scale, args.gemm2_weights, args.gemm2_weights_scale, args.output2_scales_scalar,
584-
args.gemm2_bias, workspace.gemm2_output, workspace.gemm2_output_scale, args.top_k,
585-
args.output_hidden_size.value_or(args.hidden_size), args.intermediate_size, args.local_num_experts,
586-
args.num_tokens, workspace.num_non_exiting_ctas, workspace.total_num_padded_tokens,
587-
workspace.cta_idx_xy_to_batch_idx, workspace.cta_idx_xy_to_mn_limit, workspace.bmm2_workspace, device, stream,
588-
config.gemm2Config, args.valid_hidden_size.value_or(args.hidden_size),
618+
args.gemm2_bias, workspace.gemm2_output, workspace.gemm2_output_scale, totalExpertsPerToken,
619+
args.output_hidden_size.value_or(args.hidden_size), args.intermediate_size, totalLocalExperts, args.num_tokens,
620+
workspace.num_non_exiting_ctas, workspace.total_num_padded_tokens, workspace.cta_idx_xy_to_batch_idx,
621+
workspace.cta_idx_xy_to_mn_limit, workspace.bmm2_workspace, device, stream, config.gemm2Config,
622+
args.valid_hidden_size.value_or(args.hidden_size),
589623
args.valid_intermediate_size.value_or(args.intermediate_size));
590624

591625
// Run finalize

cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.h

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -147,13 +147,13 @@ class Runner
147147
explicit Runner(int32_t tileTokensDim);
148148

149149
void run(void* routingLogits, void* routingBias, int32_t numTokens, int32_t numExperts, int32_t topK,
150-
int32_t nGroups, int32_t topkGroups, int32_t localExpertOffset, int32_t localNumExperts,
151-
float routedScalingFactor, int32_t* routingExpertIndexes, int32_t* expertCountHistogram,
152-
int32_t* permutedIdxSize, int32_t* expandedIdxToPermutedIdx, int32_t* permutedIdxToExpandedIdx,
153-
int32_t* permutedIdxToTokenIdx, void* expertWeights, int32_t* expertIds, int32_t* numTokensPerExpert,
154-
int32_t* ctaIdxXyToBatchIdx, int32_t* ctaIdxXyToMnLimit, int32_t* numNonExitingCtas,
155-
batchedGemm::trtllm::gen::Dtype dtypeElt, bool useRoutingScalesOnInput, bool useDeepSeekFp8,
156-
RoutingMethodType routingMethodType, cudaStream_t stream);
150+
int32_t numFusedSharedExpert, int32_t nGroups, int32_t topkGroups, int32_t localExpertOffset,
151+
int32_t localNumExperts, float routedScalingFactor, int32_t* routingExpertIndexes,
152+
int32_t* expertCountHistogram, int32_t* permutedIdxSize, int32_t* expandedIdxToPermutedIdx,
153+
int32_t* permutedIdxToExpandedIdx, int32_t* permutedIdxToTokenIdx, void* expertWeights, int32_t* expertIds,
154+
int32_t* numTokensPerExpert, int32_t* ctaIdxXyToBatchIdx, int32_t* ctaIdxXyToMnLimit,
155+
int32_t* numNonExitingCtas, batchedGemm::trtllm::gen::Dtype dtypeElt, bool useRoutingScalesOnInput,
156+
bool useDeepSeekFp8, RoutingMethodType routingMethodType, cudaStream_t stream);
157157

158158
private:
159159
int32_t mTileTokensDim;
@@ -268,6 +268,7 @@ struct MoERunnerArgs
268268

269269
int32_t num_tokens{0};
270270
int32_t num_experts{0};
271+
int32_t num_fused_shared_experts{0};
271272
// Hidden dimension input of MoE block. It might be padded.
272273
int32_t hidden_size{0};
273274
// Hidden dimension output of MoE block. It might be padded.

0 commit comments

Comments
 (0)