 #include "cutlass/util/tensor_view_io.h"
 
 #include "cutlass_extensions/compute_occupancy.h"
+#include "cutlass_extensions/epilogue/fusion/sm90_visitor_scatter.hpp"
 #include "cutlass_extensions/epilogue_helpers.h"
 #include "cutlass_extensions/gemm/collective/collective_builder_mixed_input.hpp"
 #include "cutlass_extensions/gemm_configs.h"
@@ -71,11 +72,12 @@ namespace cutlass_kernels_oss
 using namespace tensorrt_llm::kernels::cutlass_kernels;
 namespace tk = tensorrt_llm::common;
 namespace tkc = tensorrt_llm::cutlass_extensions;
+using EpilogueFusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion;
 
 using namespace cute;
 
-template <typename T, typename WeightType, typename GemmOutputType, typename EpilogueTag, typename CTAShape,
-    typename ClusterShape, typename MainloopScheduleType, typename EpilogueScheduleType,
+template <typename T, typename WeightType, typename GemmOutputType, typename EpilogueTag, EpilogueFusion FUSION,
+    typename CTAShape, typename ClusterShape, typename MainloopScheduleType, typename EpilogueScheduleType,
     cutlass::WeightOnlyQuantOp QuantOp>
 void sm90_generic_mixed_moe_gemm_kernelLauncher(GroupedGemmInput<T, WeightType, GemmOutputType, GemmOutputType> inputs,
     TmaWarpSpecializedGroupedGemmInput hopper_inputs, int sm_count_, size_t* workspace_size)
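The launcher now takes the fusion mode as a compile-time `EpilogueFusion FUSION` template parameter (guarded by the `static_assert` added in the next hunk). Below is a minimal, self-contained sketch of this idiom with hypothetical stand-in names, not the real launcher signature: a caller switches once on the runtime enum to pick the instantiation, and everything downstream branches at compile time.

```cpp
#include <cstdio>

// Hypothetical stand-in; the real enum lives in TmaWarpSpecializedGroupedGemmInput.
enum class EpilogueFusion { NONE, FINALIZE };

template <EpilogueFusion FUSION>
void launcher()
{
    // Reject fusions this launcher does not implement, at compile time.
    static_assert(FUSION == EpilogueFusion::NONE || FUSION == EpilogueFusion::FINALIZE,
        "Unimplemented fusion provided to launcher");
    constexpr bool IsFinalizeFusion = FUSION == EpilogueFusion::FINALIZE;
    std::puts(IsFinalizeFusion ? "finalize-fused epilogue path" : "default epilogue path");
}

// Runtime-to-compile-time dispatch: switch once, instantiate the matching specialization.
void dispatch(EpilogueFusion fusion)
{
    switch (fusion)
    {
    case EpilogueFusion::NONE: launcher<EpilogueFusion::NONE>(); break;
    case EpilogueFusion::FINALIZE: launcher<EpilogueFusion::FINALIZE>(); break;
    }
}
```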
@@ -85,6 +87,9 @@ void sm90_generic_mixed_moe_gemm_kernelLauncher(GroupedGemmInput<T, WeightType,
     ///////////////////////////////////////////////////////////////////////////////////////////////
     /// GEMM kernel configurations
     ///////////////////////////////////////////////////////////////////////////////////////////////
+    static_assert(FUSION == EpilogueFusion::NONE || FUSION == EpilogueFusion::FINALIZE,
+        "Unimplemented fusion provided to TMA WS Mixed MoE gemm launcher");
+    constexpr static bool IsFinalizeFusion = FUSION == EpilogueFusion::FINALIZE;
 
     // A matrix configuration
     using ElementA = typename TllmToCutlassTypeAdapter<T>::type;
@@ -129,13 +134,21 @@ void sm90_generic_mixed_moe_gemm_kernelLauncher(GroupedGemmInput<T, WeightType,
     using ElementD = ElementC;
     using LayoutD = LayoutC;
     constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
+    using ElementFinalOutput = typename TllmToCutlassTypeAdapter<GemmOutputType>::type;
+    using ElementBias = ElementFinalOutput;
+    using ElementRouterScales = float;
 
     // Core kernel configurations
     using ElementAccumulator = float;                     // Element type for internal accumulation
     using ArchTag = cutlass::arch::Sm90;                  // Tag indicating the minimum SM that supports the intended feature
     using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
     using TileShape = CTAShape;                           // Threadblock-level tile size
     using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
+
+    using EpilogueFusionOp = cutlass::epilogue::fusion::ScaledAccPerRowBiasPerColScaleScatter<
+        typename cutlass::layout::LayoutTranspose<LayoutD>::type, ElementFinalOutput, ElementAccumulator, ElementBias,
+        ElementRouterScales>;
+
     using KernelSchedule
         = std::conditional_t<std::is_same_v<MainloopScheduleType, cutlass::gemm::KernelTmaWarpSpecializedPingpong>,
             cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong,
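The `ScaledAccPerRowBiasPerColScaleScatter` op added above is built on the transposed D layout, so rows correspond to hidden units and columns to (expanded) tokens. Reading the op name and the arguments wired up later in this diff, the sketch below is a plausible scalar reference of what the fused finalize computes; the bias-then-scale ordering and the reduction behavior are assumptions inferred from the naming, not taken from the commit.

```cpp
#include <vector>

// Scalar reference (assumed semantics) for one expert's GEMM output in the
// transposed layout used above: shape [hidden][expanded_tokens].
void finalize_reference(std::vector<std::vector<float>> const& acc, // [hidden][expanded_tokens]
    std::vector<float> const& bias,                                 // per-row: one bias per hidden unit
    std::vector<float> const& router_scales,                        // per-column: one weight per expanded token
    std::vector<int> const& source_token_index,                     // expanded token -> original token
    bool use_reduction,
    std::vector<std::vector<float>>& final_out)                     // [hidden][num_original_tokens]
{
    for (size_t h = 0; h < acc.size(); ++h)
    {
        for (size_t t = 0; t < acc[h].size(); ++t)
        {
            // alpha is fixed to 1 in this launcher, so: scale * (acc + bias)
            float const v = router_scales[t] * (acc[h][t] + (bias.empty() ? 0.f : bias[h]));
            int const dst = source_token_index[t];
            if (use_reduction)
                final_out[h][dst] += v; // accumulate contributions across experts
            else
                final_out[h][dst] = v;
        }
    }
}
```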
@@ -145,12 +158,21 @@ void sm90_generic_mixed_moe_gemm_kernelLauncher(GroupedGemmInput<T, WeightType,
             cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong,
             cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative>; // Epilogue to launch
 
-    using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<cutlass::arch::Sm90,
+    using CollectiveEpilogueFinalize = typename cutlass::epilogue::collective::CollectiveBuilder<cutlass::arch::Sm90,
+        cutlass::arch::OpClassTensorOp, TileShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
+        ElementAccumulator, ElementAccumulator, ElementC, typename cutlass::layout::LayoutTranspose<LayoutC>::type*,
+        AlignmentC, void, typename cutlass::layout::LayoutTranspose<LayoutD>::type*, AlignmentD, EpilogueSchedule,
+        EpilogueFusionOp>::CollectiveOp;
+
+    using CollectiveEpilogueDefault = typename cutlass::epilogue::collective::CollectiveBuilder<cutlass::arch::Sm90,
         cutlass::arch::OpClassTensorOp, TileShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
         ElementAccumulator, ElementAccumulator, ElementC, typename cutlass::layout::LayoutTranspose<LayoutC>::type*,
         AlignmentC, ElementD, typename cutlass::layout::LayoutTranspose<LayoutD>::type*, AlignmentD,
         EpilogueSchedule>::CollectiveOp;
 
+    using CollectiveEpilogue
+        = std::conditional_t<IsFinalizeFusion, CollectiveEpilogueFinalize, CollectiveEpilogueDefault>;
+
     // =========================================================== MIXED INPUT WITH SCALES
     // =========================================================================== The Scale information must get paired
     // with the operand that will be scaled. In this example, B is scaled so we make a tuple of B's information and the
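Both epilogues above are fully formed alias declarations, and `std::conditional_t` selects one per instantiation, so only the chosen collective is ever built into kernel code. Note the finalize builder passes `void` for the D element, presumably because the scatter visitor performs its own store. A minimal illustration of the selection idiom, with hypothetical placeholder types:

```cpp
#include <type_traits>

struct DefaultEpilogue  { /* stand-in for the CollectiveBuilder result with ElementD */ };
struct FinalizeEpilogue { /* stand-in for the CollectiveBuilder result with a fusion op */ };

template <bool IsFinalizeFusion>
using CollectiveEpilogue
    = std::conditional_t<IsFinalizeFusion, FinalizeEpilogue, DefaultEpilogue>;

// Selection happens entirely at compile time; no runtime branch exists.
static_assert(std::is_same_v<CollectiveEpilogue<true>, FinalizeEpilogue>);
static_assert(std::is_same_v<CollectiveEpilogue<false>, DefaultEpilogue>);
```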
@@ -175,20 +197,56 @@ void sm90_generic_mixed_moe_gemm_kernelLauncher(GroupedGemmInput<T, WeightType,
     Args arguments;
 
     decltype(arguments.epilogue.thread) fusion_args;
-    fusion_args.alpha = use_wfp4a16 ? 1 : 0;
-    fusion_args.beta = 0;
-    fusion_args.alpha_ptr = nullptr;
-    fusion_args.beta_ptr = nullptr;
-    fusion_args.alpha_ptr_array = use_wfp4a16 ? nullptr : inputs.alpha_scales;
-    fusion_args.beta_ptr_array = nullptr;
-    // One alpha and beta per each group
-    fusion_args.dAlpha = {cute::_0{}, cute::_0{}, use_wfp4a16 ? 0 : 1};
-    fusion_args.dBeta = {cute::_0{}, cute::_0{}, use_wfp4a16 ? 0 : 1};
 
     cutlass::KernelHardwareInfo hw_info;
     hw_info.device_id = 0;
     hw_info.sm_count = sm_count_;
 
+    using EpilogueArguments = typename CollectiveEpilogue::Arguments;
+    using EpilogueScalars = decltype(EpilogueArguments{}.thread);
+    EpilogueScalars epilogue_scalars = [&]
+    {
+        if constexpr (IsFinalizeFusion)
+        {
+            auto epi_params = hopper_inputs.fused_finalize_epilogue;
+            return EpilogueScalars{ElementAccumulator(1), nullptr, hopper_inputs.alpha_scale_ptr_array,
+                Stride<_0, _0, int64_t>{cute::_0{}, cute::_0{}, 1}, /* alpha */
+                reinterpret_cast<ElementBias const* const*>(epi_params.ptr_bias), Stride<_1, _0, int64_t>{}, /* bias */
+                epi_params.ptr_router_scales, Stride<_0, _1, int64_t>{}, /* scale */
+                reinterpret_cast<ElementFinalOutput*>(epi_params.ptr_final_output),
+                epi_params.stride_final_output_transposed, epi_params.ptr_source_token_index,
+                epi_params.num_rows_in_final_output, epi_params.shape_override, epi_params.use_reduction};
+        }
+        else
+        {
+            return EpilogueScalars{};
+        }
+    }();
+
+    EpilogueArguments epilogue_args = [&]
+    {
+        if constexpr (IsFinalizeFusion)
+        {
+            return EpilogueArguments{epilogue_scalars, nullptr, nullptr, nullptr, nullptr};
+        }
+        else
+        {
+            fusion_args.alpha = use_wfp4a16 ? 1 : 0;
+            fusion_args.beta = 0;
+            fusion_args.alpha_ptr = nullptr;
+            fusion_args.beta_ptr = nullptr;
+            fusion_args.alpha_ptr_array = use_wfp4a16 ? nullptr : inputs.alpha_scales;
+            fusion_args.beta_ptr_array = nullptr;
+            // One alpha and beta per each group
+            fusion_args.dAlpha = {cute::_0{}, cute::_0{}, use_wfp4a16 ? 0 : 1};
+            fusion_args.dBeta = {cute::_0{}, cute::_0{}, use_wfp4a16 ? 0 : 1};
+
+            return EpilogueArguments{fusion_args, reinterpret_cast<ElementC const**>(hopper_inputs.ptr_c),
+                reinterpret_cast<StrideC*>(hopper_inputs.stride_c), reinterpret_cast<ElementD**>(hopper_inputs.ptr_d),
+                reinterpret_cast<StrideD*>(hopper_inputs.stride_d)};
+        }
+    }();
+
     arguments = Args{cutlass::gemm::GemmUniversalMode::kGrouped,
         {inputs.num_experts, hopper_inputs.int4_groupwise_params.shape.problem_shapes, nullptr},
         {reinterpret_cast<ElementB const**>(hopper_inputs.ptr_weight),
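The two `[&] { ... }()` blocks above use the immediately invoked lambda idiom: `if constexpr` discards the untaken branch before instantiation, so a brace initializer that is only well-formed for one fusion's argument struct never has to compile for the other. A self-contained sketch of the pattern with hypothetical argument structs (not the CUTLASS types):

```cpp
#include <type_traits>

struct FinalizeArgs { float alpha; float const* router_scales; };
struct DefaultArgs  { float alpha; };

template <bool IsFinalizeFusion>
auto make_epilogue_args(float const* router_scales)
{
    using Args = std::conditional_t<IsFinalizeFusion, FinalizeArgs, DefaultArgs>;
    return [&]
    {
        if constexpr (IsFinalizeFusion)
        {
            // Two-field initializer: would be ill-formed for DefaultArgs, but this
            // branch is discarded entirely when IsFinalizeFusion is false.
            return Args{1.0f, router_scales};
        }
        else
        {
            return Args{1.0f};
        }
    }();
}
```

A plain if/else would not work here: both branches would be compiled for every instantiation, and the mismatched initializer would be a hard error.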
@@ -197,10 +255,7 @@ void sm90_generic_mixed_moe_gemm_kernelLauncher(GroupedGemmInput<T, WeightType,
         reinterpret_cast<StrideA*>(hopper_inputs.stride_act),
         reinterpret_cast<ElementScalePacked const**>(hopper_inputs.int4_groupwise_params.ptr_s_a),
         reinterpret_cast<StrideS*>(hopper_inputs.int4_groupwise_params.stride_s_a), group_size},
-        {fusion_args, reinterpret_cast<ElementC const**>(hopper_inputs.ptr_c),
-            reinterpret_cast<StrideC*>(hopper_inputs.stride_c), reinterpret_cast<ElementD**>(hopper_inputs.ptr_d),
-            reinterpret_cast<StrideD*>(hopper_inputs.stride_d)},
-        hw_info};
+        epilogue_args, hw_info};
 
     assert(group_size == int(inputs.groupwise_quant_group_size));
     if (workspace_size != nullptr)