@@ -61,8 +61,8 @@ Runner::Runner(int32_t tileTokensDim)
6161}
6262
6363void Runner::run (void * routingLogits, void * routingBias, int32_t numTokens, int32_t numExperts, int32_t topK,
64- int32_t nGroup, int32_t topkGroup, int32_t localExpertOffset, int32_t localNumExperts, float routedScalingFactor ,
65- int32_t * routingExpertIndexes, int32_t * expertCountHistogram, int32_t * permutedIdxSize,
64+ int32_t numFusedSharedExpert, int32_t nGroup, int32_t topkGroup, int32_t localExpertOffset, int32_t localNumExperts,
65+ float routedScalingFactor, int32_t * routingExpertIndexes, int32_t * expertCountHistogram, int32_t * permutedIdxSize,
6666 int32_t * expandedIdxToPermutedIdx, int32_t * permutedIdxToExpandedIdx, int32_t * permutedIdxToTokenIdx,
6767 void * expertWeights, int32_t * expertIds, int32_t * numTokensPerExpert, int32_t * ctaIdxXyToBatchIdx,
6868 int32_t * ctaIdxXyToMnLimit, int32_t * numNonExitingCtas, btg::Dtype dtypeElt, bool useRoutingScalesOnInput,
@@ -76,6 +76,8 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
7676 routingData.mDtypeExpW = btg::Dtype::Bfloat16;
7777 routingData.mUsePdl = true ;
7878
79+ int32_t const totalExpertsPerToken = topK + numFusedSharedExpert;
80+
7981 // output:
8082 routingData.mPtrTopKPacked = routingExpertIndexes;
8183 routingData.mPtrExpertCounts = expertCountHistogram;
@@ -96,16 +98,35 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
9698 routingData.mPtrTopKIds = expertIds;
9799 routingData.mNumTokens = numTokens;
98100 routingData.mNumExperts = numExperts;
101+ routingData.mNumFusedSharedExperts = numFusedSharedExpert;
99102 routingData.mNumExpertGroups = nGroup;
100103 routingData.mNumLimitedGroups = topkGroup;
101104 routingData.mTopK = topK;
105+ routingData.mTotalExpertsPerToken = totalExpertsPerToken;
102106 routingData.mPaddingLog2 = computeLog2 (mTileTokensDim );
103107 routingData.mTileTokensDim = mTileTokensDim ;
104108 routingData.mLocalExpertsStartIdx = localExpertOffset;
105109 routingData.mLocalExpertsStrideLog2 = 0 ;
106110 routingData.mNumLocalExperts = localNumExperts;
107111 routingData.mRouteScale = routedScalingFactor;
108112 routingData.mUseRoutingSoftmax = false ;
113+
114+ // TODO Should these be passed directly instead? This does assume a constant number of experts per device
115+ int32_t const numDevices = numExperts / localNumExperts;
116+ int32_t const deviceIndex = localExpertOffset / localNumExperts;
117+ int32_t const baseTokensPerDevice = numTokens / numDevices;
118+ int32_t const remainingTokens = numTokens % numDevices;
119+
120+ if (deviceIndex < remainingTokens)
121+ {
122+ routingData.mSharedExpertTokenOffset = (baseTokensPerDevice + 1 ) * deviceIndex;
123+ routingData.mSharedExpertNumTokens = baseTokensPerDevice + 1 ;
124+ }
125+ else
126+ {
127+ routingData.mSharedExpertTokenOffset = remainingTokens + deviceIndex * baseTokensPerDevice;
128+ routingData.mSharedExpertNumTokens = baseTokensPerDevice;
129+ }
109130 moe::dev::routing::routingDeepSeek::run (routingData, stream);
110131 }
111132 else if (routingMethodType == RoutingMethodType::Llama4)
@@ -115,6 +136,8 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
115136 {
116137 TLLM_LOG_WARNING (" For Llama routing method, nGroup/topkGroup is ignored, got %d/%d." , nGroup, topkGroup);
117138 }
139+ TLLM_CHECK_WITH_INFO (numFusedSharedExpert == 0 , " Llama routing method does not support fusing shared expert" );
140+
118141 moe::dev::routing::routingLlama4::Data routingData;
119142 routingData.mDtypeExpW = btg::Dtype::Bfloat16;
120143 routingData.mUsePdl = true ;
@@ -159,6 +182,9 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
159182 else if (routingMethodType == RoutingMethodType::Renormalize /* default */
160183 || routingMethodType == RoutingMethodType::RenormalizeNaive /* Softmax -> TopK */ )
161184 {
185+ TLLM_CHECK_WITH_INFO (
186+ numFusedSharedExpert == 0 , " Renormalize routing method does not support fusing shared expert" );
187+
162188 moe::dev::routing::routingRenormalize::Data routingData;
163189
164190 //
@@ -434,6 +460,9 @@ void Runner::setOpsData(MoERunnerArgs const& args, MoEWorkspace const& workspace
434460 moe::dev::convertsf::Data& convertSfData, moe::dev::activation::Data& activationData,
435461 moe::dev::finalize::Data& finalizeData)
436462{
463+ int32_t const totalNumExperts = args.num_experts + args.num_fused_shared_experts ;
464+ int32_t const totalExpertsPerToken = args.top_k + args.num_fused_shared_experts ;
465+
437466 // Setup sf conversion data if needed
438467 convertSfData.inSfPtr = args.hidden_states_scale ;
439468 convertSfData.outSfPtr = workspace.hidden_states_scale_linear ;
@@ -452,7 +481,7 @@ void Runner::setOpsData(MoERunnerArgs const& args, MoEWorkspace const& workspace
452481 activationData.inDqSfsPtr = workspace.gemm1_output_scale ;
453482 activationData.outDqSfsPtr = workspace.activation_output_scale ;
454483 activationData.innerDim = args.intermediate_size * 2 ;
455- activationData.topK = args. top_k ;
484+ activationData.topK = totalExpertsPerToken; // TODO Rename topK in activation data struct
456485 activationData.numTokens = args.num_tokens ;
457486 activationData.expandedIdxToPermutedIdx = workspace.expanded_idx_to_permuted_idx ;
458487
@@ -479,8 +508,8 @@ void Runner::setOpsData(MoERunnerArgs const& args, MoEWorkspace const& workspace
479508 }
480509 finalizeData.expandedIdxToPermutedIdx = workspace.expanded_idx_to_permuted_idx ;
481510 finalizeData.numTokens = args.num_tokens ;
482- finalizeData.numExperts = args. num_experts ;
483- finalizeData.topK = args. top_k ;
511+ finalizeData.numExperts = totalNumExperts; // TODO Is this used?
512+ finalizeData.topK = totalExpertsPerToken; // TODO Rename topK in finalize data struct
484513 // We want to fuse unpadding into the finalize kernel, so we need to use the output hidden size.
485514 finalizeData.hiddenDim = args.valid_hidden_size .value_or (args.hidden_size );
486515 finalizeData.hiddenDimPadded = args.output_hidden_size .value_or (args.hidden_size );
@@ -490,12 +519,15 @@ void Runner::setOpsData(MoERunnerArgs const& args, MoEWorkspace const& workspace
490519
491520std::tuple<int32_t , int32_t > Runner::getWorkspaceSizeInBytes (MoERunnerArgs const & args, int64_t configIndex) const
492521{
522+ int32_t const totalLocalExperts = args.local_num_experts + args.num_fused_shared_experts ;
523+ int32_t const totalExpertsPerToken = args.top_k + args.num_fused_shared_experts ;
524+
493525 auto const & config = mPassingConfigs [configIndex];
494526
495- auto workspace_size_fc1 = static_cast <int32_t >(mPermuteGemm1 .getWorkspaceSizeInBytes (args. top_k , args. hidden_size ,
496- args.intermediate_size , args.local_num_experts , args.num_tokens , config.gemm1Config ));
497- auto workspace_size_fc2 = static_cast <int32_t >(mGemm2 .getWorkspaceSizeInBytes (args. top_k , args. hidden_size ,
498- args.intermediate_size , args.local_num_experts , args.num_tokens , config.gemm2Config ));
527+ auto workspace_size_fc1 = static_cast <int32_t >(mPermuteGemm1 .getWorkspaceSizeInBytes (totalExpertsPerToken ,
528+ args.hidden_size , args.intermediate_size , totalLocalExperts , args.num_tokens , config.gemm1Config ));
529+ auto workspace_size_fc2 = static_cast <int32_t >(mGemm2 .getWorkspaceSizeInBytes (totalExpertsPerToken ,
530+ args.hidden_size , args.intermediate_size , totalLocalExperts , args.num_tokens , config.gemm2Config ));
499531 return std::make_tuple (workspace_size_fc1, workspace_size_fc2);
500532}
501533
@@ -530,7 +562,6 @@ std::vector<int64_t> Runner::getValidConfigIndices(int32_t topK, int32_t hiddenS
530562int64_t Runner::getDefaultValidConfigIndex (int32_t topK, int32_t hiddenSize, int32_t intermediateSize,
531563 int32_t numLocalExperts, int32_t numTokens, int32_t validHiddenSize, int32_t validIntermediateSize) const
532564{
533-
534565 int32_t indexGemm1 = mPermuteGemm1 .getDefaultValidConfigIndex (
535566 topK, hiddenSize, intermediateSize, numLocalExperts, numTokens, validHiddenSize, validIntermediateSize);
536567 int32_t indexGemm2 = mGemm2 .getDefaultValidConfigIndex (
@@ -553,14 +584,17 @@ void Runner::run(
553584 sync_check_cuda_error (stream);
554585 setOpsData (args, workspace, convertSfData, activationData, finalizeData);
555586
587+ int32_t const totalLocalExperts = args.local_num_experts + args.num_fused_shared_experts ;
588+ int32_t const totalExpertsPerToken = args.top_k + args.num_fused_shared_experts ;
589+
556590 void * hidden_states_scale_linear{args.hidden_states_scale };
557591
558592 auto const & config = mPassingConfigs [configIndex];
559593
560594 mPermuteGemm1 .run (args.hidden_states , hidden_states_scale_linear, args.gemm1_weights , args.gemm1_weights_scale ,
561595 workspace.expert_weights , args.output1_scales_scalar , args.output1_scales_gate_scalar , args.gemm1_bias ,
562596 args.gemm1_alpha , args.gemm1_beta , args.gemm1_clamp_limit , workspace.gemm1_output , workspace.gemm1_output_scale ,
563- args. top_k , args.hidden_size , args.intermediate_size , args. local_num_experts , args.num_tokens ,
597+ totalExpertsPerToken , args.hidden_size , args.intermediate_size , totalLocalExperts , args.num_tokens ,
564598 workspace.permuted_idx_to_token_idx , workspace.num_non_exiting_ctas , workspace.total_num_padded_tokens ,
565599 workspace.cta_idx_xy_to_batch_idx , workspace.cta_idx_xy_to_mn_limit , workspace.bmm1_workspace ,
566600 args.mUseRoutingScalesOnInput , device, stream, config.gemm1Config ,
@@ -581,11 +615,11 @@ void Runner::run(
581615
582616 // Run gemm2
583617 mGemm2 .run (gemm2_input, gemm2_input_scale, args.gemm2_weights , args.gemm2_weights_scale , args.output2_scales_scalar ,
584- args.gemm2_bias , workspace.gemm2_output , workspace.gemm2_output_scale , args. top_k ,
585- args.output_hidden_size .value_or (args.hidden_size ), args.intermediate_size , args.local_num_experts ,
586- args. num_tokens , workspace.num_non_exiting_ctas , workspace.total_num_padded_tokens ,
587- workspace.cta_idx_xy_to_batch_idx , workspace. cta_idx_xy_to_mn_limit , workspace.bmm2_workspace , device, stream,
588- config. gemm2Config , args.valid_hidden_size .value_or (args.hidden_size ),
618+ args.gemm2_bias , workspace.gemm2_output , workspace.gemm2_output_scale , totalExpertsPerToken ,
619+ args.output_hidden_size .value_or (args.hidden_size ), args.intermediate_size , totalLocalExperts, args.num_tokens ,
620+ workspace. num_non_exiting_ctas , workspace.total_num_padded_tokens , workspace.cta_idx_xy_to_batch_idx ,
621+ workspace.cta_idx_xy_to_mn_limit , workspace.bmm2_workspace , device, stream, config. gemm2Config ,
622+ args.valid_hidden_size .value_or (args.hidden_size ),
589623 args.valid_intermediate_size .value_or (args.intermediate_size ));
590624
591625 // Run finalize
0 commit comments