diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_instantiation.cu b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_instantiation.cu index a81691cf9..f20729f16 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_instantiation.cu +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_instantiation.cu @@ -45,11 +45,13 @@ template class CutlassMoeFCRunner<__nv_fp4_e2m1, __nv_fp4_e2m1, half>; template class CutlassMoeFCRunner<__nv_fp4_e2m1, __nv_fp4_e2m1, half, half>; template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp4_e2m1, half>; template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp4_e2m1, half, half>; +template class CutlassMoeFCRunner<half, __nv_fp4_e2m1>; #ifdef ENABLE_BF16 template class CutlassMoeFCRunner<__nv_fp4_e2m1, __nv_fp4_e2m1, __nv_bfloat16>; template class CutlassMoeFCRunner<__nv_fp4_e2m1, __nv_fp4_e2m1, __nv_bfloat16, __nv_bfloat16>; template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp4_e2m1, __nv_bfloat16>; template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp4_e2m1, __nv_bfloat16, __nv_bfloat16>; +template class CutlassMoeFCRunner<__nv_bfloat16, __nv_fp4_e2m1>; #endif #endif }; // namespace tensorrt_llm::kernels::cutlass_kernels diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index df1a0ea70..231063c05 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -284,7 +284,6 @@ void buildMinLatencyActiveExpertMaps( num_tokens, experts_per_token, start_expert, end_expert, num_experts_per_node, smart_routing, cluster_rank, cluster_size, num_experts_smem); } - template __global__ void fusedBuildExpertMapsSortFirstTokenKernel( int const* const token_selected_experts, int* const permuted_row_to_unpermuted_row, @@ -963,13 +962,13 @@ __device__ auto quantizePackedFPXValue( TmaWarpSpecializedGroupedGemmInput::ElementSF* act_sf_flat, 
TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType scaling_type) { constexpr bool is_fp8 = std::is_same_v; - static constexpr int NumThreadsPerSF = VecSize / CVT_FP4_ELTS_PER_THREAD; + static constexpr int NumThreadsPerSF = VecSize / CVT_ELTS_PER_THREAD; // Quantize the input to FP4 static_assert(std::is_same_v || std::is_same_v); - static_assert(ComputeElem::kElements == CVT_FP4_ELTS_PER_THREAD); + static_assert(ComputeElem::kElements == CVT_ELTS_PER_THREAD); PackedVec packed_vec{}; - for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { + for (int i = 0; i < CVT_ELTS_PER_THREAD / 2; i++) { packed_vec.elts[i].x = static_cast(post_act_val[i * 2 + 0]); packed_vec.elts[i].y = static_cast(post_act_val[i * 2 + 1]); } @@ -980,10 +979,11 @@ __device__ auto quantizePackedFPXValue( // Use `token - num_tokens_before_expert` because we want this to be relative to the start of this // expert - auto sf_out = cvt_quant_to_fp4_get_sf_out_offset( - std::nullopt /* batchIdx */, token_id - num_tokens_before_expert, elem_idx, - std::nullopt /* numRows */, num_cols, act_sf_expert, FP4QuantizationSFLayout::SWIZZLED_128x4); + auto sf_out = + cvt_quant_get_sf_out_offset( + std::nullopt /* batchIdx */, token_id - num_tokens_before_expert, elem_idx, + std::nullopt /* numRows */, num_cols / VecSize, act_sf_expert, + QuantizationSFLayout::SWIZZLED_128x4); // Do the conversion and set the output and scaling factor auto func = [&]() { @@ -1020,18 +1020,18 @@ __device__ void writeSF(int64_t num_tokens_before_expert, int64_t expert_id, // Use `token - num_tokens_before_expert` because we want this to be relative to the start of this // expert - auto sf_out = cvt_quant_to_fp4_get_sf_out_offset( - std::nullopt /* batchIdx */, token_id - num_tokens_before_expert, elem_idx, - std::nullopt /* numRows */, num_cols, act_sf_expert, FP4QuantizationSFLayout::SWIZZLED_128x4); + auto sf_out = + cvt_quant_get_sf_out_offset( + std::nullopt /* batchIdx */, token_id - num_tokens_before_expert, elem_idx, 
+ std::nullopt /* numRows */, num_cols / VecSize, act_sf_expert, + QuantizationSFLayout::SWIZZLED_128x4); if (sf_out) { if (input_sf) { - auto const sf_in = - cvt_quant_to_fp4_get_sf_out_offset( - std::nullopt /* batchIdx */, source_token_id, elem_idx, std::nullopt /* numRows */, - num_cols, const_cast(input_sf), - FP4QuantizationSFLayout::SWIZZLED_128x4); + auto const sf_in = cvt_quant_get_sf_out_offset( + std::nullopt /* batchIdx */, source_token_id, elem_idx, std::nullopt /* numRows */, + num_cols / VecSize, const_cast(input_sf), + QuantizationSFLayout::SWIZZLED_128x4); *sf_out = *sf_in; } else { *sf_out = 0x00; @@ -1127,7 +1127,13 @@ __device__ void computeTmaWarpSpecializedInputStrides( if (layout_info.int4_groupwise_params.enabled) { layout_info.int4_groupwise_params.stride_s_a[out_idx] = cutlass::make_cute_packed_stride( TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::StrideSFA{}, - cute::make_shape(gemm_n, gemm_k / 128, 1)); + cute::make_shape( + gemm_n, + gemm_k / + (layout_info.int4_groupwise_params.use_wfp4a16 + ? TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::wfp4a16_group_size + : TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::int4_group_size), + 1)); } } @@ -1150,8 +1156,15 @@ __device__ void computeTmaWarpSpecializedInputPointers( safe_inc_ptr(output, num_tokens_before_expert * gemm_n); } if (layout_info.int4_groupwise_params.enabled) { - layout_info.int4_groupwise_params.ptr_s_a[out_idx] = - safe_inc_ptr(w4a8_weight_scale, expert * (gemm_n * gemm_k / 128)); + // The group size of wfp4a16 is multiplied by 2 because each scale uses 1 byte instead of 2 + // bytes + layout_info.int4_groupwise_params.ptr_s_a[out_idx] = safe_inc_ptr( + w4a8_weight_scale, + expert * + (gemm_n * gemm_k / + (layout_info.int4_groupwise_params.use_wfp4a16 + ? 
TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::wfp4a16_group_size * 2 + : TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::int4_group_size))); } } @@ -1453,7 +1466,7 @@ __global__ void expandInputRowsKernel( : TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaleVectorSize; constexpr int64_t ELEM_PER_THREAD = (is_nvfp4 || is_mxfp8) - ? CVT_FP4_ELTS_PER_THREAD + ? CVT_ELTS_PER_THREAD : (128 / sizeof_bits::value); // This should be VecSize * 4 elements @@ -1977,16 +1990,62 @@ void finalizeMoeRoutingKernelLauncher( // INSTANTIATE_FINALIZE_MOE_ROUTING(__nv_bfloat16, __nv_bfloat16, __nv_bfloat16); // #endif +// ============================== Activation Adaptors ================================= +template