From 5a9f69724a38123a65464c95a116e695d83f3ce4 Mon Sep 17 00:00:00 2001 From: Duncan Moss Date: Tue, 5 Aug 2025 15:24:30 -0700 Subject: [PATCH 01/12] add mxfp8 x mxfp4 cutlass fused moe --- .../cutlass_fused_moe_instantiation.cu | 2 + .../cutlass_fused_moe_kernels.cuh | 326 ++++++--- .../flashinfer_cutlass_fused_moe_sm100_ops.cu | 174 +++-- csrc/nv_internal/cpp/kernels/quantization.cu | 212 ++---- .../kernels/cutlass_kernels/include/common.h | 2 +- .../include/moe_gemm_kernels.h | 20 +- .../cutlass_kernels/include/moe_kernels.h | 90 ++- .../moe_gemm_tma_ws_mixed_input_launcher.inl | 30 +- .../moe_gemm/moe_gemm_kernels_bf16_fp4.cu | 24 + .../moe_gemm/moe_gemm_kernels_fp16_fp4.cu | 22 + .../moe_gemm/moe_gemm_template_dispatch.h | 103 ++- ...emm_template_dispatch_tma_ws_mixed_dtype.h | 14 +- .../moe_tma_warp_specialized_traits.h | 5 +- .../tensorrt_llm/kernels/quantization.cuh | 653 +++++++----------- .../tensorrt_llm/kernels/quantization.h | 42 +- csrc/nv_internal/tensorrt_llm/thop/fp4Op.cpp | 45 +- .../tensorrt_llm/thop/fp4Quantize.cpp | 44 +- .../tensorrt_llm/thop/fp4Quantize.h | 3 +- .../tensorrt_llm/thop/fp8Quantize.cpp | 77 ++- .../tensorrt_llm/thop/fp8Quantize.h | 9 +- csrc/trtllm_allreduce_fusion.cu | 4 +- csrc/trtllm_fused_moe_kernel_launcher.cu | 4 +- csrc/trtllm_moe_allreduce_fusion.cu | 4 +- flashinfer/__init__.py | 4 +- flashinfer/comm/__init__.py | 2 +- flashinfer/comm/trtllm_ar.py | 10 +- flashinfer/fp4_quantization.py | 35 +- flashinfer/fp8_quantization.py | 5 + flashinfer/fused_moe/core.py | 23 +- flashinfer/fused_moe/utils.py | 4 +- .../comm/trtllm_allreduce_fusion.cuh | 10 +- .../comm/trtllm_moe_allreduce_fusion.cuh | 10 +- tests/test_fp4_quantize.py | 8 +- tests/test_trtllm_allreduce_fusion.py | 6 +- tests/test_trtllm_cutlass_fused_moe.py | 125 +++- tests/test_trtllm_moe_allreduce_fusion.py | 6 +- 36 files changed, 1225 insertions(+), 932 deletions(-) create mode 100644 csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_fp4.cu create mode 100644 csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp4.cu diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_instantiation.cu b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_instantiation.cu index a81691cf9..f20729f16 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_instantiation.cu +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_instantiation.cu @@ -45,11 +45,13 @@ template class CutlassMoeFCRunner<__nv_fp4_e2m1, __nv_fp4_e2m1, half>; template class CutlassMoeFCRunner<__nv_fp4_e2m1, __nv_fp4_e2m1, half, half>; template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp4_e2m1, half>; template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp4_e2m1, half, half>; +template class CutlassMoeFCRunner; #ifdef ENABLE_BF16 template class CutlassMoeFCRunner<__nv_fp4_e2m1, __nv_fp4_e2m1, __nv_bfloat16>; template class CutlassMoeFCRunner<__nv_fp4_e2m1, __nv_fp4_e2m1, __nv_bfloat16, __nv_bfloat16>; template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp4_e2m1, __nv_bfloat16>; template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp4_e2m1, __nv_bfloat16, __nv_bfloat16>; +template class CutlassMoeFCRunner<__nv_bfloat16, __nv_fp4_e2m1>; #endif #endif }; // namespace tensorrt_llm::kernels::cutlass_kernels diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 6c8789e9a..f9151bff4 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -286,7 +286,6 @@ void buildMinLatencyActiveExpertMaps(int* num_active_experts_per_node, num_tokens, experts_per_token, start_expert, end_expert, num_experts_per_node, smart_routing, cluster_rank, cluster_size, num_experts_smem); } - template __global__ void fusedBuildExpertMapsSortFirstTokenKernel( int const* const token_selected_experts, int* const permuted_row_to_unpermuted_row, @@ -963,13 +962,13 @@ __device__ auto quantizePackedFPXValue( TmaWarpSpecializedGroupedGemmInput::ElementSF* act_sf_flat, TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType scaling_type) { constexpr bool is_fp8 = std::is_same_v; - static constexpr int NumThreadsPerSF = VecSize / CVT_FP4_ELTS_PER_THREAD; + static constexpr int NumThreadsPerSF = VecSize / CVT_ELTS_PER_THREAD; // Quantize the input to FP4 static_assert(std::is_same_v || std::is_same_v); - static_assert(ComputeElem::kElements == CVT_FP4_ELTS_PER_THREAD); + static_assert(ComputeElem::kElements == CVT_ELTS_PER_THREAD); PackedVec packed_vec{}; - for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { + for (int i = 0; i < CVT_ELTS_PER_THREAD / 2; i++) { packed_vec.elts[i].x = static_cast(post_act_val[i * 2 + 0]); packed_vec.elts[i].y = static_cast(post_act_val[i * 2 + 1]); } @@ -980,10 +979,9 @@ __device__ auto quantizePackedFPXValue( // Use `token - num_tokens_before_expert` because we want this to be relative to the start of this // expert - auto sf_out = cvt_quant_to_fp4_get_sf_out_offset( + auto sf_out = cvt_quant_get_sf_out_offset( std::nullopt /* batchIdx */, token_id - num_tokens_before_expert, elem_idx, - std::nullopt /* numRows */, num_cols, act_sf_expert, FP4QuantizationSFLayout::SWIZZLED_128x4); + std::nullopt /* numRows */, num_cols / VecSize, act_sf_expert, QuantizationSFLayout::SWIZZLED); // Do the conversion and set the output and scaling factor auto func = [&]() { @@ -1020,18 +1018,16 @@ __device__ void writeSF(int64_t num_tokens_before_expert, int64_t expert_id, // Use `token - num_tokens_before_expert` because we want this to be relative to the start of this // expert - auto sf_out = cvt_quant_to_fp4_get_sf_out_offset( + auto sf_out = cvt_quant_get_sf_out_offset( std::nullopt /* batchIdx */, token_id - num_tokens_before_expert, elem_idx, - std::nullopt /* numRows */, num_cols, act_sf_expert, FP4QuantizationSFLayout::SWIZZLED_128x4); + std::nullopt /* numRows */, num_cols / VecSize, act_sf_expert, QuantizationSFLayout::SWIZZLED); if (sf_out) { if (input_sf) { auto const sf_in = - cvt_quant_to_fp4_get_sf_out_offset( + cvt_quant_get_sf_out_offset( std::nullopt /* batchIdx */, source_token_id, elem_idx, std::nullopt /* numRows */, - num_cols, const_cast(input_sf), - FP4QuantizationSFLayout::SWIZZLED_128x4); + num_cols / VecSize, const_cast(input_sf), + QuantizationSFLayout::SWIZZLED); *sf_out = *sf_in; } else { *sf_out = 0x00; @@ -1127,7 +1123,12 @@ __device__ void computeTmaWarpSpecializedInputStrides( if (layout_info.int4_groupwise_params.enabled) { layout_info.int4_groupwise_params.stride_s_a[out_idx] = cutlass::make_cute_packed_stride( TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::StrideSFA{}, - cute::make_shape(gemm_n, gemm_k / 128, 1)); + cute::make_shape(gemm_n, + gemm_k + / (layout_info.int4_groupwise_params.use_wfp4a16 + ? TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::wfp4a16_group_size + : TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::int4_group_size), + 1)); } } @@ -1150,8 +1151,14 @@ __device__ void computeTmaWarpSpecializedInputPointers( safe_inc_ptr(output, num_tokens_before_expert * gemm_n); } if (layout_info.int4_groupwise_params.enabled) { - layout_info.int4_groupwise_params.ptr_s_a[out_idx] = - safe_inc_ptr(w4a8_weight_scale, expert * (gemm_n * gemm_k / 128)); + // The group size of wfp4a16 is multiplied by 2 because each scale uses 1 byte instead of 2 bytes + layout_info.int4_groupwise_params.ptr_s_a[out_idx] = safe_inc_ptr(w4a8_weight_scale, + expert + * (gemm_n * gemm_k + / (layout_info.int4_groupwise_params.use_wfp4a16 + ? TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::wfp4a16_group_size * 2 + : TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::int4_group_size))); + } } @@ -1453,7 +1460,7 @@ __global__ void expandInputRowsKernel( : TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaleVectorSize; constexpr int64_t ELEM_PER_THREAD = (is_nvfp4 || is_mxfp8) - ? CVT_FP4_ELTS_PER_THREAD + ? CVT_ELTS_PER_THREAD : (128 / sizeof_bits::value); // This should be VecSize * 4 elements @@ -1977,16 +1984,67 @@ void finalizeMoeRoutingKernelLauncher( // INSTANTIATE_FINALIZE_MOE_ROUTING(__nv_bfloat16, __nv_bfloat16, __nv_bfloat16); // #endif +// ============================== Activation Adaptors ================================= +template