@@ -393,6 +393,12 @@ void trtllm_fp8_block_scale_moe_launcher(
393
393
int32_t max_num_padded_tokens =
394
394
tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxPermutedPaddedCount (
395
395
args.num_tokens , top_k, num_experts, tile_tokens_dim);
396
+ int32_t max_num_padded_tokens_gemm1 =
397
+ tensorrt_llm::kernels::trtllmgen_moe::Routing::maybeGetMinTokenCount (
398
+ max_num_padded_tokens, args.intermediate_size , btg::dtypeGetNumBits (args.mDtypeElt ));
399
+ int32_t max_num_padded_tokens_gemm2 =
400
+ tensorrt_llm::kernels::trtllmgen_moe::Routing::maybeGetMinTokenCount (
401
+ max_num_padded_tokens, args.hidden_size , btg::dtypeGetNumBits (args.mDtypeOut ));
396
402
Tensor total_num_padded_tokens = alloc_tensor ({1 }, dl_int32, routing_logits->device );
397
403
Tensor expanded_idx_to_permuted_idx =
398
404
alloc_tensor ({args.num_tokens * args.top_k }, dl_int32, routing_logits->device );
@@ -413,16 +419,16 @@ void trtllm_fp8_block_scale_moe_launcher(
413
419
// dl_float8_e4m3fn, hidden_states->device);
414
420
// Tensor activation_output = alloc_tensor({max_num_padded_tokens, intermediate_size},
415
421
// dl_float8_e4m3fn, hidden_states->device);
416
- Tensor gemm1_output =
417
- alloc_tensor ({max_num_padded_tokens, 2 * intermediate_size}, dl_uint8, hidden_states->device );
422
+ Tensor gemm1_output = alloc_tensor ({max_num_padded_tokens_gemm1, 2 * intermediate_size}, dl_uint8,
423
+ hidden_states->device );
418
424
Tensor gemm1_output_scale = alloc_tensor ({2 * intermediate_size / 128 , max_num_padded_tokens},
419
425
dl_float32, hidden_states->device );
420
- Tensor activation_output =
421
- alloc_tensor ({max_num_padded_tokens, intermediate_size}, dl_uint8, hidden_states->device );
422
- Tensor activation_output_scale = alloc_tensor ({intermediate_size / 128 , max_num_padded_tokens},
423
- dl_float32, hidden_states->device );
424
- Tensor gemm2_output =
425
- alloc_tensor ({max_num_padded_tokens, args. hidden_size }, dl_bfloat16, hidden_states->device );
426
+ Tensor activation_output = alloc_tensor ({max_num_padded_tokens_gemm1, intermediate_size},
427
+ dl_uint8, hidden_states->device );
428
+ Tensor activation_output_scale = alloc_tensor (
429
+ {intermediate_size / 128 , max_num_padded_tokens_gemm1}, dl_float32, hidden_states->device );
430
+ Tensor gemm2_output = alloc_tensor ({max_num_padded_tokens_gemm2, args. hidden_size }, dl_bfloat16,
431
+ hidden_states->device );
426
432
427
433
int32_t max_num_ctas = tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxNumCtasInBatchDim (
428
434
args.num_tokens , args.top_k , args.num_experts , tile_tokens_dim);
@@ -519,7 +525,8 @@ void trtllm_fp8_block_scale_moe_launcher(
519
525
520
526
// setup workspace
521
527
workspace.total_num_padded_tokens = static_cast <int *>(total_num_padded_tokens->data );
522
- workspace.total_max_padded_tokens = max_num_padded_tokens;
528
+ workspace.total_max_padded_tokens =
529
+ std::max (max_num_padded_tokens_gemm1, max_num_padded_tokens_gemm2);
523
530
workspace.ProjUpTileN = tile_tokens_dim;
524
531
workspace.routing_expert_indexes = static_cast <int *>(expert_indexes->data );
525
532
workspace.permuted_idx_size = static_cast <int *>(total_num_padded_tokens->data );
@@ -764,6 +771,12 @@ Array<Tensor> trtllm_fp4_block_scale_moe_launcher(
764
771
int32_t max_num_padded_tokens =
765
772
tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxPermutedPaddedCount (
766
773
args.num_tokens , top_k, num_experts, tile_tokens_dim);
774
+ int32_t max_num_padded_tokens_gemm1 =
775
+ tensorrt_llm::kernels::trtllmgen_moe::Routing::maybeGetMinTokenCount (
776
+ max_num_padded_tokens, args.intermediate_size , btg::dtypeGetNumBits (args.mDtypeElt ));
777
+ int32_t max_num_padded_tokens_gemm2 =
778
+ tensorrt_llm::kernels::trtllmgen_moe::Routing::maybeGetMinTokenCount (
779
+ max_num_padded_tokens, args.hidden_size , btg::dtypeGetNumBits (args.mDtypeOut ));
767
780
Tensor total_num_padded_tokens = alloc_tensor ({1 }, dl_int32, hidden_states->device );
768
781
Tensor expanded_idx_to_permuted_idx =
769
782
alloc_tensor ({args.num_tokens , args.top_k }, dl_int32, hidden_states->device );
@@ -788,20 +801,20 @@ Array<Tensor> trtllm_fp4_block_scale_moe_launcher(
788
801
// Tensor gemm1_output = alloc_tensor(
789
802
// {max_num_padded_tokens, gemm1_output_hidden},
790
803
// dtype_act == btg::Dtype::Bfloat16 ? dl_bfloat16 : dl_float8_e4m3fn, hidden_states->device);
791
- Tensor gemm1_output = alloc_tensor ({max_num_padded_tokens , gemm1_output_hidden},
804
+ Tensor gemm1_output = alloc_tensor ({max_num_padded_tokens_gemm1 , gemm1_output_hidden},
792
805
dtype_act == btg::Dtype::Bfloat16 ? dl_bfloat16 : dl_uint8,
793
806
hidden_states->device );
794
807
795
808
Optional<Tensor> gemm1_output_scale = std::nullopt ;
796
809
if (dtype_act == btg::Dtype::E2m1 || dtype_act == btg::Dtype::MxE4m3) {
797
- int64_t sf_size = tensorrt_llm::computeSwizzledLayoutSFSize (max_num_padded_tokens ,
810
+ int64_t sf_size = tensorrt_llm::computeSwizzledLayoutSFSize (max_num_padded_tokens_gemm1 ,
798
811
intermediate_size / sf_vec_size);
799
812
// gemm1_output_scale = alloc_tensor({sf_size}, dl_float8_e4m3fn, hidden_states->device);
800
813
gemm1_output_scale = alloc_tensor ({sf_size}, dl_uint8, hidden_states->device );
801
814
}
802
815
803
- Tensor gemm2_output =
804
- alloc_tensor ({max_num_padded_tokens, args. hidden_size }, dl_bfloat16, hidden_states->device );
816
+ Tensor gemm2_output = alloc_tensor ({max_num_padded_tokens_gemm2, args. hidden_size }, dl_bfloat16,
817
+ hidden_states->device );
805
818
806
819
int32_t max_num_ctas = tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxNumCtasInBatchDim (
807
820
args.num_tokens , args.top_k , args.num_experts , tile_tokens_dim);
@@ -958,7 +971,8 @@ Array<Tensor> trtllm_fp4_block_scale_moe_launcher(
958
971
959
972
// setup workspace
960
973
workspace.total_num_padded_tokens = static_cast <int *>(total_num_padded_tokens->data );
961
- workspace.total_max_padded_tokens = max_num_padded_tokens;
974
+ workspace.total_max_padded_tokens =
975
+ std::max (max_num_padded_tokens_gemm1, max_num_padded_tokens_gemm2);
962
976
workspace.ProjUpTileN = tile_tokens_dim;
963
977
workspace.routing_expert_indexes = static_cast <int *>(expert_indices->data );
964
978
workspace.permuted_idx_size = static_cast <int *>(total_num_padded_tokens->data );
0 commit comments