
Commit 9cae727

yumin066 and rosenrodt authored
[https://nvbugs/5726962][feat] Apply fusion for W4AFP8_AWQ MoE (#9838)
Signed-off-by: Min Yu <171526537+yumin066@users.noreply.github.com>
Signed-off-by: Anthony Chang <27950904+rosenrodt@users.noreply.github.com>
Co-authored-by: Anthony Chang <27950904+rosenrodt@users.noreply.github.com>
1 parent 6b8ae6f commit 9cae727

9 files changed: +792 -336 lines changed


cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h

Lines changed: 23 additions & 8 deletions
@@ -315,7 +315,8 @@ struct QuantParams
     {
         struct GroupwiseGemmInputs
         {
-            void const* act_scales = nullptr;
+            bool use_per_expert_act_scale = false;
+            void const* act_scales = nullptr; // (1 or num_experts_per_node, hidden_size or intermediate_size)
             void const* weight_scales = nullptr;
             void const* weight_zeros = nullptr;
             float const* alpha = nullptr;
@@ -401,12 +402,15 @@ struct QuantParams
     static QuantParams GroupWise(int group_size, void const* fc1_weight_scales, void const* fc2_weight_scales,
         void const* fc1_activation_scales = nullptr, void const* fc2_activation_scales = nullptr,
         void const* fc1_weight_zeros = nullptr, void const* fc2_weight_zeros = nullptr,
-        float const* fc1_alpha = nullptr, float const* fc2_alpha = nullptr)
+        float const* fc1_alpha = nullptr, float const* fc2_alpha = nullptr, bool fc1_use_per_expert_act_scale = false,
+        bool fc2_use_per_expert_act_scale = false)
     {
         QuantParams qp;
         qp.groupwise.group_size = group_size;
-        qp.groupwise.fc1 = {fc1_activation_scales, fc1_weight_scales, fc1_weight_zeros, fc1_alpha};
-        qp.groupwise.fc2 = {fc2_activation_scales, fc2_weight_scales, fc2_weight_zeros, fc2_alpha};
+        qp.groupwise.fc1
+            = {fc1_use_per_expert_act_scale, fc1_activation_scales, fc1_weight_scales, fc1_weight_zeros, fc1_alpha};
+        qp.groupwise.fc2
+            = {fc2_use_per_expert_act_scale, fc2_activation_scales, fc2_weight_scales, fc2_weight_zeros, fc2_alpha};
         return qp;
     }

@@ -646,7 +650,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
         int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node,
         ActivationParams fc1_activation_type, float const** alpha_scale_ptr_array, bool bias_is_broadcast,
         cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode,
-        int* num_active_experts_per, int* active_expert_global_ids);
+        int* num_active_experts_per, int* active_expert_global_ids, void const* fc2_prequant_scale = nullptr);

     static void gemm2(MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>& gemm_runner,
         DeepSeekBlockScaleGemmRunner* fp8_blockscale_gemm_runner, T const* const input, void* const gemm_output,
@@ -803,6 +807,16 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
         bool min_latency_mode, bool use_awq);

 private:
+    static bool useAwq(cutlass_kernels::QuantParams const& quant_params)
+    {
+        return quant_params.groupwise.fc1.act_scales && quant_params.groupwise.fc2.act_scales && !use_wfp4a16;
+    }
+
+    static bool usePrequantScaleKernel(cutlass_kernels::QuantParams const& quant_params)
+    {
+        return useAwq(quant_params) && !std::is_same_v<T, WeightType>;
+    }
+
     bool mayHaveDifferentGEMMOutputType() const
     {
         // We just check if its supported because we need to know when calculating workspace size
@@ -813,13 +827,13 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
     bool mayHaveFinalizeFused() const
     {
         return moe_gemm_runner_.supportsTmaWarpSpecialized() && moe_gemm_runner_.getSM() >= 90 && use_fused_finalize_
-            && !use_w4_groupwise;
+            && !use_wfp4a16;
     }

     static bool mayHaveFinalizeFused(int sm)
     {
         using RunnerType = decltype(moe_gemm_runner_);
-        return RunnerType::supportsTmaWarpSpecialized(sm) && sm >= 90 && !use_w4_groupwise;
+        return RunnerType::supportsTmaWarpSpecialized(sm) && sm >= 90 && !use_wfp4a16;
     }

     // TODO: This should eventually take the quant params to give more flexibility
@@ -866,7 +880,8 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface

     T const* applyPrequantScale(void* smoothed_act, void const* permuted_data, void const* prequant_scales,
         int64_t const* num_valid_tokens_ptr, int64_t const expanded_num_rows, int64_t const seq_len, bool const use_awq,
-        cudaStream_t stream, int64_t* expert_first_token_offset = nullptr, int const num_experts_per_node = 0);
+        cudaStream_t stream, QuantParams const& quant_params, int64_t* expert_first_token_offset = nullptr,
+        int const num_experts_per_node = 0);

     MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType> moe_gemm_runner_;
     std::unique_ptr<DeepSeekBlockScaleGemmRunner> blockscale_gemm_runner_;
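As a side note on the GroupWise() change above, here is a minimal, self-contained sketch of the same pattern. The types are hypothetical stand-ins, not the TensorRT-LLM structs: the new per-expert activation-scale flags are appended as defaulted parameters and stored as the first field of each FC input block, so existing call sites keep compiling while new callers can opt in per GEMM.

```cpp
#include <cstdio>

// Stand-ins for the real structs; the field order mirrors the diff above
// (the bool comes first so the aggregate brace lists match the new code).
struct GroupwiseGemmInputsSketch
{
    bool use_per_expert_act_scale = false;
    void const* act_scales = nullptr; // (1 or num_experts_per_node, hidden or intermediate size)
    void const* weight_scales = nullptr;
};

struct QuantParamsSketch
{
    int group_size = 0;
    GroupwiseGemmInputsSketch fc1;
    GroupwiseGemmInputsSketch fc2;

    // New flags are appended with defaults, so older call sites still compile.
    static QuantParamsSketch GroupWise(int group_size, void const* fc1_weight_scales, void const* fc2_weight_scales,
        void const* fc1_act_scales = nullptr, void const* fc2_act_scales = nullptr,
        bool fc1_use_per_expert_act_scale = false, bool fc2_use_per_expert_act_scale = false)
    {
        QuantParamsSketch qp;
        qp.group_size = group_size;
        qp.fc1 = {fc1_use_per_expert_act_scale, fc1_act_scales, fc1_weight_scales};
        qp.fc2 = {fc2_use_per_expert_act_scale, fc2_act_scales, fc2_weight_scales};
        return qp;
    }
};

int main()
{
    float fc1_ws[8]{}, fc2_ws[8]{}, fc1_as[8]{}, fc2_as[8]{};
    // Legacy-style call: the per-expert flags default to false.
    auto legacy = QuantParamsSketch::GroupWise(128, fc1_ws, fc2_ws, fc1_as, fc2_as);
    // Opt-in call: one activation scale per expert for both FC1 and FC2.
    auto per_expert = QuantParamsSketch::GroupWise(128, fc1_ws, fc2_ws, fc1_as, fc2_as, true, true);
    std::printf("legacy=%d per_expert=%d\n", int(legacy.fc1.use_per_expert_act_scale),
        int(per_expert.fc1.use_per_expert_act_scale));
    return 0;
}
```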

cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.h

Lines changed: 3 additions & 2 deletions
@@ -28,8 +28,9 @@ namespace cutlass_kernels_oss
 {
 using tensorrt_llm::kernels::cutlass_kernels::GroupedGemmInput;
 using tensorrt_llm::kernels::cutlass_kernels::TmaWarpSpecializedGroupedGemmInput;
-template <typename T, typename WeightType, typename GemmOutputType, typename EpilogueTag, typename CTAShape,
-    typename ClusterShape, typename MainloopScheduleType, typename EpilogueScheduleType,
+using EpilogueFusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion;
+template <typename T, typename WeightType, typename GemmOutputType, typename EpilogueTag, EpilogueFusion FUSION,
+    typename CTAShape, typename ClusterShape, typename MainloopScheduleType, typename EpilogueScheduleType,
     cutlass::WeightOnlyQuantOp QuantOp>
 void sm90_generic_mixed_moe_gemm_kernelLauncher(
     tensorrt_llm::kernels::cutlass_kernels::GroupedGemmInput<T, WeightType, GemmOutputType, GemmOutputType> inputs,

cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl

Lines changed: 71 additions & 16 deletions
@@ -45,6 +45,7 @@
 #include "cutlass/util/tensor_view_io.h"

 #include "cutlass_extensions/compute_occupancy.h"
+#include "cutlass_extensions/epilogue/fusion/sm90_visitor_scatter.hpp"
 #include "cutlass_extensions/epilogue_helpers.h"
 #include "cutlass_extensions/gemm/collective/collective_builder_mixed_input.hpp"
 #include "cutlass_extensions/gemm_configs.h"
@@ -71,11 +72,12 @@ namespace cutlass_kernels_oss
 using namespace tensorrt_llm::kernels::cutlass_kernels;
 namespace tk = tensorrt_llm::common;
 namespace tkc = tensorrt_llm::cutlass_extensions;
+using EpilogueFusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion;

 using namespace cute;

-template <typename T, typename WeightType, typename GemmOutputType, typename EpilogueTag, typename CTAShape,
-    typename ClusterShape, typename MainloopScheduleType, typename EpilogueScheduleType,
+template <typename T, typename WeightType, typename GemmOutputType, typename EpilogueTag, EpilogueFusion FUSION,
+    typename CTAShape, typename ClusterShape, typename MainloopScheduleType, typename EpilogueScheduleType,
     cutlass::WeightOnlyQuantOp QuantOp>
 void sm90_generic_mixed_moe_gemm_kernelLauncher(GroupedGemmInput<T, WeightType, GemmOutputType, GemmOutputType> inputs,
     TmaWarpSpecializedGroupedGemmInput hopper_inputs, int sm_count_, size_t* workspace_size)
@@ -85,6 +87,9 @@ void sm90_generic_mixed_moe_gemm_kernelLauncher(GroupedGemmInput<T, WeightType,
     /////////////////////////////////////////////////////////////////////////////////////////////////
     /// GEMM kernel configurations
     /////////////////////////////////////////////////////////////////////////////////////////////////
+    static_assert(FUSION == EpilogueFusion::NONE || FUSION == EpilogueFusion::FINALIZE,
+        "Unimplemented fusion provided to TMA WS Mixed MoE gemm launcher");
+    constexpr static bool IsFinalizeFusion = FUSION == EpilogueFusion::FINALIZE;

     // A matrix configuration
     using ElementA = typename TllmToCutlassTypeAdapter<T>::type;
@@ -129,13 +134,21 @@ void sm90_generic_mixed_moe_gemm_kernelLauncher(GroupedGemmInput<T, WeightType,
     using ElementD = ElementC;
     using LayoutD = LayoutC;
     constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
+    using ElementFinalOutput = typename TllmToCutlassTypeAdapter<GemmOutputType>::type;
+    using ElementBias = ElementFinalOutput;
+    using ElementRouterScales = float;

     // Core kernel configurations
     using ElementAccumulator = float; // Element type for internal accumulation
     using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
     using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
     using TileShape = CTAShape; // Threadblock-level tile size
     using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
+
+    using EpilogueFusionOp = cutlass::epilogue::fusion::ScaledAccPerRowBiasPerColScaleScatter<
+        typename cutlass::layout::LayoutTranspose<LayoutD>::type, ElementFinalOutput, ElementAccumulator, ElementBias,
+        ElementRouterScales>;
+
     using KernelSchedule
         = std::conditional_t<std::is_same_v<MainloopScheduleType, cutlass::gemm::KernelTmaWarpSpecializedPingpong>,
             cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong,
@@ -145,12 +158,21 @@ void sm90_generic_mixed_moe_gemm_kernelLauncher(GroupedGemmInput<T, WeightType,
             cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong,
             cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative>; // Epilogue to launch

-    using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<cutlass::arch::Sm90,
+    using CollectiveEpilogueFinalize = typename cutlass::epilogue::collective::CollectiveBuilder<cutlass::arch::Sm90,
+        cutlass::arch::OpClassTensorOp, TileShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
+        ElementAccumulator, ElementAccumulator, ElementC, typename cutlass::layout::LayoutTranspose<LayoutC>::type*,
+        AlignmentC, void, typename cutlass::layout::LayoutTranspose<LayoutD>::type*, AlignmentD, EpilogueSchedule,
+        EpilogueFusionOp>::CollectiveOp;
+
+    using CollectiveEpilogueDefault = typename cutlass::epilogue::collective::CollectiveBuilder<cutlass::arch::Sm90,
         cutlass::arch::OpClassTensorOp, TileShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
         ElementAccumulator, ElementAccumulator, ElementC, typename cutlass::layout::LayoutTranspose<LayoutC>::type*,
         AlignmentC, ElementD, typename cutlass::layout::LayoutTranspose<LayoutD>::type*, AlignmentD,
         EpilogueSchedule>::CollectiveOp;

+    using CollectiveEpilogue
+        = std::conditional_t<IsFinalizeFusion, CollectiveEpilogueFinalize, CollectiveEpilogueDefault>;
+
     // =========================================================== MIXED INPUT WITH SCALES
     // =========================================================================== The Scale information must get paired
     // with the operand that will be scaled. In this example, B is scaled so we make a tuple of B's information and the
@@ -175,20 +197,56 @@ void sm90_generic_mixed_moe_gemm_kernelLauncher(GroupedGemmInput<T, WeightType,
     Args arguments;

     decltype(arguments.epilogue.thread) fusion_args;
-    fusion_args.alpha = use_wfp4a16 ? 1 : 0;
-    fusion_args.beta = 0;
-    fusion_args.alpha_ptr = nullptr;
-    fusion_args.beta_ptr = nullptr;
-    fusion_args.alpha_ptr_array = use_wfp4a16 ? nullptr : inputs.alpha_scales;
-    fusion_args.beta_ptr_array = nullptr;
-    // One alpha and beta per each group
-    fusion_args.dAlpha = {cute::_0{}, cute::_0{}, use_wfp4a16 ? 0 : 1};
-    fusion_args.dBeta = {cute::_0{}, cute::_0{}, use_wfp4a16 ? 0 : 1};

     cutlass::KernelHardwareInfo hw_info;
     hw_info.device_id = 0;
     hw_info.sm_count = sm_count_;

+    using EpilogueArguments = typename CollectiveEpilogue::Arguments;
+    using EpilogueScalars = decltype(EpilogueArguments{}.thread);
+    EpilogueScalars epilogue_scalars = [&]
+    {
+        if constexpr (IsFinalizeFusion)
+        {
+            auto epi_params = hopper_inputs.fused_finalize_epilogue;
+            return EpilogueScalars{ElementAccumulator(1), nullptr, hopper_inputs.alpha_scale_ptr_array,
+                Stride<_0, _0, int64_t>{cute::_0{}, cute::_0{}, 1}, /* alpha */
+                reinterpret_cast<ElementBias const* const*>(epi_params.ptr_bias), Stride<_1, _0, int64_t>{}, /* bias */
+                epi_params.ptr_router_scales, Stride<_0, _1, int64_t>{}, /* scale */
+                reinterpret_cast<ElementFinalOutput*>(epi_params.ptr_final_output),
+                epi_params.stride_final_output_transposed, epi_params.ptr_source_token_index,
+                epi_params.num_rows_in_final_output, epi_params.shape_override, epi_params.use_reduction};
+        }
+        else
+        {
+            return EpilogueScalars{};
+        }
+    }();
+
+    EpilogueArguments epilogue_args = [&]
+    {
+        if constexpr (IsFinalizeFusion)
+        {
+            return EpilogueArguments{epilogue_scalars, nullptr, nullptr, nullptr, nullptr};
+        }
+        else
+        {
+            fusion_args.alpha = use_wfp4a16 ? 1 : 0;
+            fusion_args.beta = 0;
+            fusion_args.alpha_ptr = nullptr;
+            fusion_args.beta_ptr = nullptr;
+            fusion_args.alpha_ptr_array = use_wfp4a16 ? nullptr : inputs.alpha_scales;
+            fusion_args.beta_ptr_array = nullptr;
+            // One alpha and beta per each group
+            fusion_args.dAlpha = {cute::_0{}, cute::_0{}, use_wfp4a16 ? 0 : 1};
+            fusion_args.dBeta = {cute::_0{}, cute::_0{}, use_wfp4a16 ? 0 : 1};
+
+            return EpilogueArguments{fusion_args, reinterpret_cast<ElementC const**>(hopper_inputs.ptr_c),
+                reinterpret_cast<StrideC*>(hopper_inputs.stride_c), reinterpret_cast<ElementD**>(hopper_inputs.ptr_d),
+                reinterpret_cast<StrideD*>(hopper_inputs.stride_d)};
+        }
+    }();
+
     arguments = Args{cutlass::gemm::GemmUniversalMode::kGrouped,
         {inputs.num_experts, hopper_inputs.int4_groupwise_params.shape.problem_shapes, nullptr},
         {reinterpret_cast<ElementB const**>(hopper_inputs.ptr_weight),
@@ -197,10 +255,7 @@ void sm90_generic_mixed_moe_gemm_kernelLauncher(GroupedGemmInput<T, WeightType,
         reinterpret_cast<StrideA*>(hopper_inputs.stride_act),
         reinterpret_cast<ElementScalePacked const**>(hopper_inputs.int4_groupwise_params.ptr_s_a),
         reinterpret_cast<StrideS*>(hopper_inputs.int4_groupwise_params.stride_s_a), group_size},
-        {fusion_args, reinterpret_cast<ElementC const**>(hopper_inputs.ptr_c),
-            reinterpret_cast<StrideC*>(hopper_inputs.stride_c), reinterpret_cast<ElementD**>(hopper_inputs.ptr_d),
-            reinterpret_cast<StrideD*>(hopper_inputs.stride_d)},
-        hw_info};
+        epilogue_args, hw_info};

     assert(group_size == int(inputs.groupwise_quant_group_size));
     if (workspace_size != nullptr)
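To make the shape of the new launcher code above easier to follow, here is a compile-time-dispatch sketch in the same style; the type and field names are illustrative placeholders, not the CUTLASS or TensorRT-LLM ones. The idea mirrored from the diff: the fused and default epilogues take different argument structs, so the launcher picks the struct type with std::conditional_t and fills it inside an immediately-invoked lambda guarded by if constexpr, so only the branch matching the fusion mode is ever instantiated.

```cpp
#include <cstdio>
#include <type_traits>

enum class EpilogueFusion { NONE, FINALIZE };

// Placeholder argument structs with deliberately different layouts,
// standing in for the default and finalize-fusion epilogue arguments.
struct DefaultEpilogueArgs { float alpha; float beta; };
struct FinalizeEpilogueArgs { float const* router_scales; float* final_output; };

template <EpilogueFusion FUSION>
void launch_sketch(float const* router_scales, float* final_output)
{
    static_assert(FUSION == EpilogueFusion::NONE || FUSION == EpilogueFusion::FINALIZE,
        "Unimplemented fusion provided to the sketch launcher");
    constexpr bool IsFinalizeFusion = FUSION == EpilogueFusion::FINALIZE;

    // Compile-time selection of the argument type, as in the real launcher.
    using EpilogueArguments = std::conditional_t<IsFinalizeFusion, FinalizeEpilogueArgs, DefaultEpilogueArgs>;

    // Immediately-invoked lambda: only the branch matching FUSION is instantiated,
    // so each argument type sees only the fields it actually has.
    EpilogueArguments epilogue_args = [&]
    {
        if constexpr (IsFinalizeFusion)
        {
            return EpilogueArguments{router_scales, final_output};
        }
        else
        {
            return EpilogueArguments{1.0f, 0.0f};
        }
    }();
    (void) epilogue_args;

    std::printf("launched with finalize fusion: %d\n", int(IsFinalizeFusion));
}

int main()
{
    float scales[4]{}, out[4]{};
    launch_sketch<EpilogueFusion::NONE>(scales, out);
    launch_sketch<EpilogueFusion::FINALIZE>(scales, out);
    return 0;
}
```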

cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h

Lines changed: 20 additions & 7 deletions
@@ -792,25 +792,37 @@ void MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::dispatchToArch(
         TLLM_CHECK_WITH_INFO(
             inputs.gemm_config.is_tma_warp_specialized, "w4afp8 is only supported for TMA warp specialization");
         // EpilogueTag is ignored
+#define SM90_DISPATCH_MOE_MIXED_GEMM_TO_CUTLASS_SELECT_FINALIZE(SCALE_FACTOR)                                         \
+    if (hopper_inputs.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE)                         \
+    {                                                                                                                 \
+        cutlass_kernels_oss::sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass<T, WeightType, ScaleBiasType,             \
+            cutlass_extensions::EpilogueOpDefault, TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE,     \
+            SCALE_FACTOR>(inputs, hopper_inputs, multi_processor_count_, nullptr);                                    \
+    }                                                                                                                 \
+    else                                                                                                              \
+    {                                                                                                                 \
+        cutlass_kernels_oss::sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass<T, WeightType, ScaleBiasType,             \
+            cutlass_extensions::EpilogueOpDefault, TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE,         \
+            SCALE_FACTOR>(inputs, hopper_inputs, multi_processor_count_, nullptr);                                    \
+    }

         if (inputs.k % 512 == 0)
         {
-            cutlass_kernels_oss::sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass<T, WeightType, ScaleBiasType,
-                cutlass_extensions::EpilogueOpDefault, 4>(inputs, hopper_inputs, multi_processor_count_, nullptr);
+            SM90_DISPATCH_MOE_MIXED_GEMM_TO_CUTLASS_SELECT_FINALIZE(4)
         }
         else if (inputs.k % 256 == 0)
         {
-            cutlass_kernels_oss::sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass<T, WeightType, ScaleBiasType,
-                cutlass_extensions::EpilogueOpDefault, 2>(inputs, hopper_inputs, multi_processor_count_, nullptr);
+            SM90_DISPATCH_MOE_MIXED_GEMM_TO_CUTLASS_SELECT_FINALIZE(2)
         }
         else if (inputs.k % 128 == 0)
         {
-            cutlass_kernels_oss::sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass<T, WeightType, ScaleBiasType,
-                cutlass_extensions::EpilogueOpDefault, 1>(inputs, hopper_inputs, multi_processor_count_, nullptr);
+            SM90_DISPATCH_MOE_MIXED_GEMM_TO_CUTLASS_SELECT_FINALIZE(1)
         }
         else
        {
             TLLM_THROW("Invalid GEMM K size %d", (int) inputs.k);
         }
+#undef SM90_DISPATCH_MOE_MIXED_GEMM_TO_CUTLASS_SELECT_FINALIZE
         return;
     }

@@ -820,7 +832,8 @@
             inputs.gemm_config.is_tma_warp_specialized, "wfp4a16 is only supported for TMA warp specialization");
         // EpilogueTag is ignored
         cutlass_kernels_oss::sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass<T, WeightType, ScaleBiasType,
-            cutlass_extensions::EpilogueOpDefault, 1>(inputs, hopper_inputs, multi_processor_count_, nullptr);
+            cutlass_extensions::EpilogueOpDefault, TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE, 1>(
+            inputs, hopper_inputs, multi_processor_count_, nullptr);
         return;
     }
 #endif
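For readers unfamiliar with the macro introduced above, here is a minimal sketch of the same runtime-to-compile-time dispatch; all names are placeholders, not the TensorRT-LLM ones. The fusion mode is only known at runtime, but the launcher needs it as a template argument, so each K-alignment branch expands a small if/else that instantiates the launcher once per fusion value.

```cpp
#include <cstdio>

enum class EpilogueFusion { NONE, FINALIZE };

// Placeholder launcher: the fusion mode and scale packing factor are
// compile-time template parameters, as in the real dispatch.
template <EpilogueFusion FUSION, int ScalePackFactor>
void dispatch_sketch(int k)
{
    std::printf("k=%d pack=%d finalize=%d\n", k, ScalePackFactor, int(FUSION == EpilogueFusion::FINALIZE));
}

// Maps the runtime fusion flag to a compile-time template argument,
// mirroring SM90_DISPATCH_MOE_MIXED_GEMM_TO_CUTLASS_SELECT_FINALIZE.
#define DISPATCH_SELECT_FINALIZE(SCALE_FACTOR)                                                                        \
    if (fusion == EpilogueFusion::FINALIZE)                                                                           \
    {                                                                                                                 \
        dispatch_sketch<EpilogueFusion::FINALIZE, SCALE_FACTOR>(k);                                                   \
    }                                                                                                                 \
    else                                                                                                              \
    {                                                                                                                 \
        dispatch_sketch<EpilogueFusion::NONE, SCALE_FACTOR>(k);                                                       \
    }

void dispatch_by_k(int k, EpilogueFusion fusion)
{
    // The K-alignment branches pick the scale packing factor, as in the diff.
    if (k % 512 == 0) { DISPATCH_SELECT_FINALIZE(4) }
    else if (k % 256 == 0) { DISPATCH_SELECT_FINALIZE(2) }
    else if (k % 128 == 0) { DISPATCH_SELECT_FINALIZE(1) }
    else { std::printf("invalid K %d\n", k); }
}

#undef DISPATCH_SELECT_FINALIZE

int main()
{
    dispatch_by_k(4096, EpilogueFusion::FINALIZE); // k % 512 == 0 -> pack factor 4, finalize path
    dispatch_by_k(384, EpilogueFusion::NONE);      // k % 128 == 0 -> pack factor 1, default path
    return 0;
}
```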
