Skip to content

Commit dd9a333

Browse files
authored
feat: [GPT-OSS] Add MXFP8 x MXFP4 CUTLASS MOE for SM100 and BF16 x MXFP4 CUTLASS for SM90 + SwigluBias Activation (#1396)
## 📌 Description This PR adds MXFP8 x MXFP4 CUTLASS MOE with SwigluBias for SM100 GPUs. It also adds BF16 x MXFP4 CUTLASS MOE with SwigluBias for SM90. ## 🔍 Related Issues N/A ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [x] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. -->
1 parent 4c1ac5f commit dd9a333

File tree

79 files changed

+9268
-990
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

79 files changed

+9268
-990
lines changed

csrc/fused_moe/cutlass_backend/cutlass_fused_moe_instantiation.cu

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,13 @@ template class CutlassMoeFCRunner<__nv_fp4_e2m1, __nv_fp4_e2m1, half>;
4545
template class CutlassMoeFCRunner<__nv_fp4_e2m1, __nv_fp4_e2m1, half, half>;
4646
template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp4_e2m1, half>;
4747
template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp4_e2m1, half, half>;
48+
template class CutlassMoeFCRunner<half, __nv_fp4_e2m1>;
4849
#ifdef ENABLE_BF16
4950
template class CutlassMoeFCRunner<__nv_fp4_e2m1, __nv_fp4_e2m1, __nv_bfloat16>;
5051
template class CutlassMoeFCRunner<__nv_fp4_e2m1, __nv_fp4_e2m1, __nv_bfloat16, __nv_bfloat16>;
5152
template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp4_e2m1, __nv_bfloat16>;
5253
template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp4_e2m1, __nv_bfloat16, __nv_bfloat16>;
54+
template class CutlassMoeFCRunner<__nv_bfloat16, __nv_fp4_e2m1>;
5355
#endif
5456
#endif
5557
}; // namespace tensorrt_llm::kernels::cutlass_kernels

csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh

Lines changed: 266 additions & 101 deletions
Large diffs are not rendered by default.

csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_ops.cu

Lines changed: 135 additions & 36 deletions
Large diffs are not rendered by default.

csrc/nv_internal/cpp/kernels/quantization.cu

Lines changed: 59 additions & 140 deletions
Large diffs are not rendered by default.

csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/detail/collective/mixed_input_utils.hpp

Lines changed: 599 additions & 0 deletions
Large diffs are not rendered by default.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,244 @@
1+
/*
2+
* Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
#pragma once
19+
20+
#include "cutlass/arch/mma.h"
21+
#include "cutlass/gemm/collective/builders/sm90_common.inl"
22+
#include "cutlass/gemm/dispatch_policy.hpp"
23+
#include "cutlass/gemm/gemm.h"
24+
25+
// SM90 Collective Builders should be used only starting CUDA 12.0
26+
#if (__CUDACC_VER_MAJOR__ >= 12)
27+
#define CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
28+
#endif
29+
30+
/////////////////////////////////////////////////////////////////////////////////////////////////
31+
32+
namespace cutlass::gemm::collective {
33+
34+
/////////////////////////////////////////////////////////////////////////////////////////////////
35+
36+
namespace detail {
37+
38+
// Returns the maximum number of smem tiles that can be used with a given smem capacity, or
39+
// overrides with manual count.
40+
template <int CapacityBytes, class ElementA, class ElementB, class TileShapeMNK, bool SwapAB,
41+
int carveout_bytes>
42+
constexpr int compute_stage_count_or_override_gated(
43+
StageCountAutoCarveout<carveout_bytes> stage_count) {
44+
// 32 bytes to account for barriers etc.
45+
constexpr int stage_barrier_bytes = 32;
46+
constexpr int a_bits = static_cast<int>(sizeof_bits<ElementA>::value);
47+
constexpr int b_bits = static_cast<int>(sizeof_bits<ElementB>::value);
48+
constexpr int stage_bytes = [&]() -> int {
49+
if constexpr (SwapAB) {
50+
return (a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{}) * 2) / 8 +
51+
(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) / 8 + stage_barrier_bytes;
52+
} else {
53+
return (a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) / 8 +
54+
(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{}) * 2) / 8 +
55+
stage_barrier_bytes;
56+
}
57+
}();
58+
59+
return (CapacityBytes - carveout_bytes) / stage_bytes;
60+
}
61+
62+
} // namespace detail
63+
64+
/////////////////////////////////////////////////////////////////////////////////////////////////
65+
66+
// GMMA_TMA_WS_SS
//
// Partial specialization of CollectiveBuilderGated for SM90 tensor-op GEMMs with the
// TMA warp-specialized kernel schedules listed in the enable_if below, taken only when
// operand A is not staged through registers (is_use_rmem_A is false), i.e. both
// operands feed GMMA directly from shared memory (the "SS" path).
template <class ElementA, class GmemLayoutA, int AlignmentA, class ElementB, class GmemLayoutB,
          int AlignmentB, class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK,
          class StageCountType, class KernelScheduleType,
          template <class /* ElementCompute */> class Activation, bool SwapAB>
struct CollectiveBuilderGated<
    arch::Sm90, arch::OpClassTensorOp, ElementA, GmemLayoutA, AlignmentA, ElementB, GmemLayoutB,
    AlignmentB, ElementAccumulator, TileShape_MNK, ClusterShape_MNK, StageCountType,
    KernelScheduleType, Activation, SwapAB,
    cute::enable_if_t<
        (cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecialized> ||
         cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpong> ||
         cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperative> ||
         cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperative>) &&
        not detail::is_use_rmem_A<ElementA, GmemLayoutA, ElementB, GmemLayoutB>()>> {
  // Tile and cluster shapes must be compile-time constants.
  static_assert(is_static<TileShape_MNK>::value);
  static_assert(is_static<ClusterShape_MNK>::value);
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
  static_assert(cutlass::detail::dependent_false<ElementA>,
                "Unsupported Toolkit for SM90 Collective Builder\n");
#endif
  static_assert(
      detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
      "Should meet TMA alignment requirement\n");

  // Grouped/pointer-array GEMM gets a dedicated mainloop dispatch policy (see
  // DispatchPolicy below).
  static constexpr bool IsArrayOfPointersGemm =
      (cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperative>);
  static constexpr bool IsFP8Input = detail::is_input_fp8<ElementA, ElementB>();
  // FP8 inputs on the pointer-array schedule are only supported via the FastAccum
  // builder (the next specialization in this file).
  static_assert(!IsFP8Input || (IsFP8Input && !IsArrayOfPointersGemm),
                "Kernel[Array/Group]TmaWarpSpecializedCooperative is only compatible with FP8 "
                "FastAccum version right now\n");

  // For fp32 types, map to tf32 MMA value type
  using MmaElementA = cute::conditional_t<cute::is_same_v<ElementA, float>, tfloat32_t, ElementA>;
  using MmaElementB = cute::conditional_t<cute::is_same_v<ElementB, float>, tfloat32_t, ElementB>;

  // Map the gmem layout tags to the GMMA majorness of each operand.
  static constexpr cute::GMMA::Major GmmaMajorA =
      detail::gmma_ss_tag_to_major_A<MmaElementA, GmemLayoutA>();
  static constexpr cute::GMMA::Major GmmaMajorB =
      detail::gmma_ss_tag_to_major_B<MmaElementB, GmemLayoutB>();

  // Cooperative (and pointer-array cooperative) schedules use two MMA atoms along M;
  // all other schedules use a single atom.
  using AtomLayoutMNK = cute::conditional_t<
      cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperative> ||
          IsArrayOfPointersGemm,
      Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;

  // Tiled GMMA selected from the smem-sourced (SS) op selector for the chosen
  // element types, accumulator, tile shape and majorness.
  using TiledMma = decltype(cute::make_tiled_mma(
      cute::GMMA::ss_op_selector<MmaElementA, MmaElementB, ElementAccumulator, TileShape_MNK,
                                 GmmaMajorA, GmmaMajorB>(),
      AtomLayoutMNK{}));

  // TMA copy atoms; the cluster extent along the opposite mode decides whether a
  // multicast TMA atom is used (A is keyed on cluster dim 1, B on cluster dim 0).
  using GmemTiledCopyA =
      decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
  using GmemTiledCopyB =
      decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));

  // Shared-memory layout atoms chosen per operand from majorness, element type,
  // and the (M,K) / (N,K) tile extents.
  using SmemLayoutAtomA =
      decltype(detail::ss_smem_selector<GmmaMajorA, MmaElementA,
                                        decltype(cute::get<0>(TileShape_MNK{})),
                                        decltype(cute::get<2>(TileShape_MNK{}))>());
  using SmemLayoutAtomB =
      decltype(detail::ss_smem_selector<GmmaMajorB, MmaElementB,
                                        decltype(cute::get<1>(TileShape_MNK{})),
                                        decltype(cute::get<2>(TileShape_MNK{}))>());

  // Pipeline depth derived from smem capacity via the gated stage-count helper
  // (accounts for the doubled operand of the gated GEMM).
  static constexpr int PipelineStages =
      detail::compute_stage_count_or_override_gated<detail::sm90_smem_capacity_bytes, MmaElementA,
                                                    MmaElementB, TileShape_MNK, SwapAB>(
          StageCountType{});
  using DispatchPolicy = cute::conditional_t<
      IsArrayOfPointersGemm,
      MainloopSm90ArrayTmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
      /* For FP8 use a separate mainloop compared to other datatypes */
      cute::conditional_t<IsFP8Input,
                          MainloopSm90TmaGmmaWarpSpecializedFP8<PipelineStages, ClusterShape_MNK,
                                                                KernelScheduleType>,
                          MainloopSm90TmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK,
                                                             KernelScheduleType>>>;

  // No smem->rmem copy atoms: in the SS path GMMA reads operands directly from smem.
  using SmemCopyAtomA = void;
  using SmemCopyAtomB = void;

  // The assembled gated collective mainloop, parameterized on the Activation
  // functor and the SwapAB flag.
  using CollectiveOp =
      CollectiveMmaGated<DispatchPolicy, TileShape_MNK, ElementA, TagToStrideA_t<GmemLayoutA>,
                         ElementB, TagToStrideB_t<GmemLayoutB>, TiledMma, GmemTiledCopyA,
                         SmemLayoutAtomA, SmemCopyAtomA, cute::identity, GmemTiledCopyB,
                         SmemLayoutAtomB, SmemCopyAtomB, cute::identity, Activation, SwapAB>;
};
154+
155+
/////////////////////////////////////////////////////////////////////////////////////////////////
156+
157+
// GMMA_TMA_WS_FP8_FAST_ACCUM_SS
//
// Partial specialization of CollectiveBuilderGated for SM90 FP8 fast-accumulation
// TMA warp-specialized kernel schedules (including the pointer-array cooperative
// variant). Like the builder above, both operands feed GMMA from shared memory.
template <class ElementA, class GmemLayoutA, int AlignmentA, class ElementB, class GmemLayoutB,
          int AlignmentB, class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK,
          class StageCountType, class KernelScheduleType,
          template <class /* ElementCompute */> class Activation, bool SwapAB>
struct CollectiveBuilderGated<
    arch::Sm90, arch::OpClassTensorOp, ElementA, GmemLayoutA, AlignmentA, ElementB, GmemLayoutB,
    AlignmentB, ElementAccumulator, TileShape_MNK, ClusterShape_MNK, StageCountType,
    KernelScheduleType, Activation, SwapAB,
    cute::enable_if_t<
        cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedFP8FastAccum> ||
        cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpongFP8FastAccum> ||
        cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperativeFP8FastAccum> ||
        cute::is_same_v<KernelScheduleType,
                        KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum>>> {
  // Tile and cluster shapes must be compile-time constants.
  static_assert(is_static<TileShape_MNK>::value);
  static_assert(is_static<ClusterShape_MNK>::value);
  static_assert(
      detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
      "Not meet TMA alignment requirement yet\n");
  // These schedules are FP8-only.
  static_assert(detail::is_input_fp8<ElementA, ElementB>(),
                "Only FP8 datatypes are compatible with these kernel schedules\n");
  // Dispatch TN fp8 kernels only to TMA warp specialized FP8 builder
  static_assert(!detail::is_use_rmem_A<ElementA, GmemLayoutA, ElementB, GmemLayoutB>(),
                "Not supported for fp8 non-TN warp specialized kernels yet\n");
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
  static_assert(cutlass::detail::dependent_false<ElementA>,
                "Unsupported Toolkit for SM90 Collective Builder\n");
#endif

  // Map the gmem layout tags to the GMMA majorness of each operand.
  static constexpr cute::GMMA::Major GmmaMajorA =
      detail::gmma_ss_tag_to_major_A<ElementA, GmemLayoutA>();
  static constexpr cute::GMMA::Major GmmaMajorB =
      detail::gmma_ss_tag_to_major_B<ElementB, GmemLayoutB>();

  // Grouped/pointer-array GEMM gets a dedicated mainloop dispatch policy below.
  static constexpr bool IsArrayOfPointersGemm =
      (cute::is_same_v<KernelScheduleType,
                       KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum>);
  // Cooperative schedules use two MMA atoms along M; others use one.
  using AtomLayoutMNK = cute::conditional_t<
      cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperativeFP8FastAccum> ||
          IsArrayOfPointersGemm,
      Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;

  // Tiled GMMA from the smem-sourced (SS) op selector; FP8 elements are used
  // directly (no tf32 mapping needed here).
  using TiledMma = decltype(cute::make_tiled_mma(
      cute::GMMA::ss_op_selector<ElementA, ElementB, ElementAccumulator, TileShape_MNK, GmmaMajorA,
                                 GmmaMajorB>(),
      AtomLayoutMNK{}));

  // TMA copy atoms; cluster extent along the opposite mode selects multicast.
  using GmemTiledCopyA =
      decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
  using GmemTiledCopyB =
      decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));

  // Shared-memory layout atoms per operand.
  using SmemLayoutAtomA =
      decltype(detail::ss_smem_selector<GmmaMajorA, ElementA,
                                        decltype(cute::get<0>(TileShape_MNK{})),
                                        decltype(cute::get<2>(TileShape_MNK{}))>());
  using SmemLayoutAtomB =
      decltype(detail::ss_smem_selector<GmmaMajorB, ElementB,
                                        decltype(cute::get<1>(TileShape_MNK{})),
                                        decltype(cute::get<2>(TileShape_MNK{}))>());

  // Pipeline depth from smem capacity via the gated stage-count helper.
  static constexpr int PipelineStages =
      detail::compute_stage_count_or_override_gated<detail::sm90_smem_capacity_bytes, ElementA,
                                                    ElementB, TileShape_MNK, SwapAB>(
          StageCountType{});
  // Pointer-array schedules use the array mainloop; otherwise the plain
  // warp-specialized mainloop (fast-accum: no separate FP8 mainloop here).
  using DispatchPolicy = cute::conditional_t<
      IsArrayOfPointersGemm,
      MainloopSm90ArrayTmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
      MainloopSm90TmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>>;

  // No smem->rmem copy atoms: GMMA reads operands directly from smem (SS path).
  using SmemCopyAtomA = void;
  using SmemCopyAtomB = void;

  // The assembled gated collective mainloop, parameterized on the Activation
  // functor and the SwapAB flag.
  using CollectiveOp =
      CollectiveMmaGated<DispatchPolicy, TileShape_MNK, ElementA, TagToStrideA_t<GmemLayoutA>,
                         ElementB, TagToStrideB_t<GmemLayoutB>, TiledMma, GmemTiledCopyA,
                         SmemLayoutAtomA, SmemCopyAtomA, cute::identity, GmemTiledCopyB,
                         SmemLayoutAtomB, SmemCopyAtomB, cute::identity, Activation, SwapAB>;
};
237+
238+
/////////////////////////////////////////////////////////////////////////////////////////////////
239+
240+
/////////////////////////////////////////////////////////////////////////////////////////////////
241+
242+
} // namespace cutlass::gemm::collective
243+
244+
/////////////////////////////////////////////////////////////////////////////////////////////////

0 commit comments

Comments
 (0)