 #pragma once
 
+#include "cuda_utils.h"
 #include "cutlass/cutlass.h"
 #include "cutlass/numeric_types.h"
 
@@ -22,49 +23,49 @@ namespace vllm {
 
 using namespace cute;
 
-template <typename OutType, typename MmaTileShape, typename ScalesPerTile,
-          class ClusterShape, typename EpilogueScheduler,
-          typename MainloopScheduler>
+// clang-format off
+template <class OutType, int ScaleGranularityM,
+          int ScaleGranularityN, int ScaleGranularityK,
+          class MmaTileShape, class ClusterShape,
+          class EpilogueScheduler, class MainloopScheduler,
+          bool swap_ab_ = false>
 struct cutlass_3x_gemm_fp8_blockwise {
+  static constexpr bool swap_ab = swap_ab_;
   using ElementAB = cutlass::float_e4m3_t;
 
   using ElementA = ElementAB;
   using LayoutA = cutlass::layout::RowMajor;
+  using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
   static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
 
   using ElementB = ElementAB;
   using LayoutB = cutlass::layout::ColumnMajor;
+  using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type;
   static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
 
-  using ElementC = void;
   using ElementD = OutType;
   using LayoutD = cutlass::layout::RowMajor;
+  using LayoutD_Transpose = typename cutlass::layout::LayoutTranspose<LayoutD>::type;
   static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
 
+  using ElementC = void;  // TODO: support bias
   using LayoutC = LayoutD;
+  using LayoutC_Transpose = LayoutD_Transpose;
   static constexpr int AlignmentC = AlignmentD;
 
   using ElementAccumulator = float;
   using ElementCompute = float;
   using ElementBlockScale = float;
 
-  // MMA and Cluster Tile Shapes
-  // Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster
-  // Shape %2 == 0 using MmaTileShape_MNK = Shape<_128,_128,_128>;
-  static constexpr int ScaleMsPerTile = size<0>(ScalesPerTile{});
-  static constexpr int ScaleGranularityM =
-      size<0>(MmaTileShape{}) / ScaleMsPerTile;
-  static constexpr int ScaleGranularityN =
-      size<1>(MmaTileShape{}) / size<1>(ScalesPerTile{});
-  static constexpr int ScaleGranularityK =
-      size<2>(MmaTileShape{}) / size<2>(ScalesPerTile{});
-
-  // Shape of the threadblocks in a cluster
-  using ClusterShape_MNK = ClusterShape;
-
-  using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig<
-      ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
-      cute::UMMA::Major::MN, cute::UMMA::Major::K>;
+  using ScaleConfig = conditional_t<swap_ab,
+    cutlass::detail::Sm100BlockwiseScaleConfig<
+      ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
+      cute::UMMA::Major::K, cute::UMMA::Major::MN>,
+    cutlass::detail::Sm100BlockwiseScaleConfig<
+      ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
+      cute::UMMA::Major::MN, cute::UMMA::Major::K>>;
+
+  // layout_SFA and layout_SFB cannot be swapped since they are deduced.
   using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
   using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
 
@@ -73,7 +74,6 @@ struct cutlass_3x_gemm_fp8_blockwise {
 
   static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
   using ElementScalar = float;
-  // clang-format off
   using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<ElementD, ElementCompute, ElementC, ElementScalar, RoundStyle>;
   using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
       ArchTag,
@@ -84,33 +84,47 @@ struct cutlass_3x_gemm_fp8_blockwise {
       ElementAccumulator,
       ElementCompute,
       ElementC,
-      LayoutC,
+      conditional_t<swap_ab, LayoutC_Transpose, LayoutC>,
       AlignmentC,
       ElementD,
-      LayoutD,
+      conditional_t<swap_ab, LayoutD_Transpose, LayoutD>,
       AlignmentD,
       EpilogueScheduler,
       DefaultOperation
   >::CollectiveOp;
 
   using StageCountType = cutlass::gemm::collective::StageCountAuto;
-  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
-      ArchTag,
-      OperatorClass,
-      ElementA,
-      cute::tuple<LayoutA, LayoutSFA>,
-      AlignmentA,
-      ElementB,
-      cute::tuple<LayoutB, LayoutSFB>,
-      AlignmentB,
-      ElementAccumulator,
-      MmaTileShape,
-      ClusterShape,
-
+  using CollectiveMainloop = conditional_t<swap_ab,
+    typename cutlass::gemm::collective::CollectiveBuilder<
+      ArchTag,
+      OperatorClass,
+      ElementB,
+      cute::tuple<LayoutB_Transpose, LayoutSFA>,
+      AlignmentB,
+      ElementA,
+      cute::tuple<LayoutA_Transpose, LayoutSFB>,
+      AlignmentA,
+      ElementAccumulator,
+      MmaTileShape,
+      ClusterShape,
       cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
-      MainloopScheduler
-  >::CollectiveOp;
-  // clang-format on
+      MainloopScheduler
+    >::CollectiveOp,
+    typename cutlass::gemm::collective::CollectiveBuilder<
+      ArchTag,
+      OperatorClass,
+      ElementA,
+      cute::tuple<LayoutA, LayoutSFA>,
+      AlignmentA,
+      ElementB,
+      cute::tuple<LayoutB, LayoutSFB>,
+      AlignmentB,
+      ElementAccumulator,
+      MmaTileShape,
+      ClusterShape,
+      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
+      MainloopScheduler
+    >::CollectiveOp>;
 
   using KernelType = enable_sm100_only<cutlass::gemm::kernel::GemmUniversal<
       Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue>>;
@@ -123,6 +137,7 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
                                     torch::Tensor const& b,
                                     torch::Tensor const& a_scales,
                                     torch::Tensor const& b_scales) {
+  static constexpr bool swap_ab = Gemm::swap_ab;
   using GemmKernel = typename Gemm::GemmKernel;
   using StrideA = typename Gemm::GemmKernel::StrideA;
   using StrideB = typename Gemm::GemmKernel::StrideB;
@@ -136,7 +151,6 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
   using ElementD = typename Gemm::ElementD;
 
   int32_t m = a.size(0), n = b.size(1), k = a.size(1);
-  auto prob_shape = cute::make_shape(m, n, k, 1);
 
   StrideA a_stride;
   StrideB b_stride;
@@ -146,21 +160,36 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
   b_stride =
       cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
   c_stride =
-      cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1));
+      cutlass::make_cute_packed_stride(StrideC{}, swap_ab ? cute::make_shape(n, m, 1) : cute::make_shape(m, n, 1));
 
-  LayoutSFA layout_SFA =
+  LayoutSFA layout_SFA = swap_ab ?
+      ScaleConfig::tile_atom_to_shape_SFA(make_shape(n, m, k, 1)) :
       ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
-  LayoutSFB layout_SFB =
+  LayoutSFB layout_SFB = swap_ab ?
+      ScaleConfig::tile_atom_to_shape_SFB(make_shape(n, m, k, 1)) :
       ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));
 
   auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
   auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
   auto a_scales_ptr = static_cast<float*>(a_scales.data_ptr());
   auto b_scales_ptr = static_cast<float*>(b_scales.data_ptr());
 
-  typename GemmKernel::MainloopArguments mainloop_args{
-      a_ptr, a_stride, b_ptr, b_stride,
-      a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB};
+  auto mainloop_args = [&](){
+    // layout_SFA and layout_SFB cannot be swapped since they are deduced.
+    if (swap_ab) {
+      return typename GemmKernel::MainloopArguments{
+        b_ptr, b_stride, a_ptr, a_stride,
+        b_scales_ptr, layout_SFA, a_scales_ptr, layout_SFB
+      };
+    }
+    else {
+      return typename GemmKernel::MainloopArguments{
+        a_ptr, a_stride, b_ptr, b_stride,
+        a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB
+      };
+    }
+  }();
+  auto prob_shape = swap_ab ? cute::make_shape(n, m, k, 1) : cute::make_shape(m, n, k, 1);
 
   auto c_ptr = static_cast<ElementD*>(out.data_ptr());
   typename GemmKernel::EpilogueArguments epilogue_args{
@@ -175,29 +204,74 @@ void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out,
                                                torch::Tensor const& b,
                                                torch::Tensor const& a_scales,
                                                torch::Tensor const& b_scales) {
-  auto m = a.size(0);
-  auto k = a.size(1);
-  auto n = b.size(1);
-  int sms;
+  int32_t m = a.size(0), n = b.size(1), k = a.size(1), sms;
   cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device());
 
-  auto should_use_2sm = [&sms](int m, int n, int tile1SM = 128) {
-    return std::ceil(static_cast<float>(m) / tile1SM) *
-               std::ceil(static_cast<float>(n) / tile1SM) >=
-           sms;
-  };
-  bool use_2sm = should_use_2sm(m, n);
-  if (use_2sm) {
-    cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
-        OutType, Shape<_256, _128, _128>, Shape<_256, _1, _1>,
-        Shape<_2, _2, _1>, cutlass::epilogue::TmaWarpSpecialized2Sm,
-        cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>(
-        out, a, b, a_scales, b_scales);
+  constexpr int TILE_K = 128;
+  // TODO: better heuristics
+  bool swap_ab = (m < 16) || (m % 4 != 0);
+  bool use_tma_epilogue = (m * n) % 4 == 0;
+  if (!swap_ab) {
+    constexpr int TILE_N = 128;
+    int tile_m = 256;
+    if (cuda_utils::ceil_div(n, TILE_N) * cuda_utils::ceil_div(m, 64) <= sms) {
+      tile_m = 64;
+    }
+    else if (cuda_utils::ceil_div(n, TILE_N) * cuda_utils::ceil_div(m, 128) <= sms) {
+      tile_m = 128;
+    }
+    if (tile_m == 64) {
+      if (use_tma_epilogue) {
+        cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
+            OutType, 1, TILE_N, TILE_K, Shape<_64, Int<TILE_N>, Int<TILE_K>>,
+            Shape<_1, _1, _1>, cutlass::epilogue::TmaWarpSpecialized1Sm,
+            cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
+            out, a, b, a_scales, b_scales);
+      } else {
+        cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
+            OutType, 1, TILE_N, TILE_K, Shape<_64, Int<TILE_N>, Int<TILE_K>>,
+            Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm,
+            cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
+            out, a, b, a_scales, b_scales);
+      }
+    } else if (tile_m == 128) {
+      if (use_tma_epilogue) {
+        cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
+            OutType, 1, TILE_N, TILE_K, Shape<_128, Int<TILE_N>, Int<TILE_K>>,
+            Shape<_1, _1, _1>, cutlass::epilogue::TmaWarpSpecialized1Sm,
+            cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
+            out, a, b, a_scales, b_scales);
+      } else {
+        cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
+            OutType, 1, TILE_N, TILE_K, Shape<_128, Int<TILE_N>, Int<TILE_K>>,
+            Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm,
+            cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
+            out, a, b, a_scales, b_scales);
+      }
+    } else {  // tile_m == 256
+      if (use_tma_epilogue) {
+        cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
+            OutType, 1, TILE_N, TILE_K, Shape<_256, Int<TILE_N>, Int<TILE_K>>,
+            Shape<_2, _1, _1>, cutlass::epilogue::TmaWarpSpecialized2Sm,
+            cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>(
+            out, a, b, a_scales, b_scales);
+      } else {
+        cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
+            OutType, 1, TILE_N, TILE_K, Shape<_256, Int<TILE_N>, Int<TILE_K>>,
+            Shape<_2, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized2Sm,
+            cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>(
+            out, a, b, a_scales, b_scales);
+      }
+    }
   } else {
+    // TODO: Test more tile N configs
+    constexpr int TILE_M = 128;
+    constexpr int TILE_N = 16;
+    // TMA epilogue isn't compatible with Swap A/B
     cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
-        OutType, Shape<_128, _128, _128>, Shape<_128, _1, _1>,
-        Shape<_1, _1, _1>, cutlass::epilogue::TmaWarpSpecialized1Sm,
-        cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
+        OutType, TILE_M, 1, TILE_K, Shape<Int<TILE_M>, Int<TILE_N>, Int<TILE_K>>,
+        Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm,
+        cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100, true>>(
+        out, a, b, a_scales, b_scales);
   }
 }
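The swap-A/B path in this diff leans on the identity (A * B)^T = B^T * A^T: when swap_ab is true the mainloop receives B as the first operand and A as the second, the C/D layouts are transposed, and the problem shape becomes (n, m, k). The following standalone sketch (plain C++, no CUTLASS or CUDA; matrix names and sizes are purely illustrative) checks that identity on small row-major buffers.

// Minimal sketch of the swap-A/B trick: computing Ct (n x m) = B^T * A^T and
// reading it with a transposed layout gives the same result as C (m x n) = A * B.
#include <cassert>
#include <cstdio>
#include <vector>

// C (m x n, row-major) = A (m x k, row-major) * B (k x n, row-major)
static void gemm_rowmajor(int m, int n, int k, const std::vector<float>& A,
                          const std::vector<float>& B, std::vector<float>& C) {
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j) {
      float acc = 0.f;
      for (int p = 0; p < k; ++p) acc += A[i * k + p] * B[p * n + j];
      C[i * n + j] = acc;
    }
}

int main() {
  const int m = 3, n = 5, k = 4;  // arbitrary small sizes
  std::vector<float> A(m * k), B(k * n);
  for (int i = 0; i < m * k; ++i) A[i] = static_cast<float>(i % 7) - 3.f;
  for (int i = 0; i < k * n; ++i) B[i] = static_cast<float>(i % 5) * 0.5f;

  // Direct product: C = A * B, stored row-major as (m x n).
  std::vector<float> C(m * n);
  gemm_rowmajor(m, n, k, A, B, C);

  // Swapped product: Ct = B^T * A^T, stored row-major as (n x m). A row-major
  // (m x k) matrix reinterpreted column-major is its transpose; it is
  // materialized here only for clarity.
  std::vector<float> At(k * m), Bt(n * k), Ct(n * m);
  for (int i = 0; i < m; ++i)
    for (int p = 0; p < k; ++p) At[p * m + i] = A[i * k + p];
  for (int p = 0; p < k; ++p)
    for (int j = 0; j < n; ++j) Bt[j * k + p] = B[p * n + j];
  gemm_rowmajor(n, m, k, Bt, At, Ct);

  // Reading Ct through a transposed layout recovers C, which is why the
  // swap_ab path transposes LayoutC/LayoutD and swaps (m, n) in prob_shape.
  // The accumulation order is identical, so the comparison is exact.
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j) assert(C[i * n + j] == Ct[j * m + i]);
  std::puts("swap-A/B identity holds");
  return 0;
}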
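The tile_m selection in cutlass_gemm_blockwise_sm100_fp8_dispatch prefers the smallest M tile (64, then 128) whose grid of output tiles still covers every SM, and otherwise falls back to the 256-row, 2-SM configuration. Below is a rough standalone sketch of that wave-count heuristic under stated assumptions: ceil_div is a local stand-in for cuda_utils::ceil_div, and the SM count is an illustrative placeholder for the cudaDevAttrMultiProcessorCount query.

// Rough sketch of the dispatch heuristic: pick the smallest M tile for which
// one CTA per output tile is still enough to occupy all SMs.
#include <cstdio>

static int ceil_div(int a, int b) { return (a + b - 1) / b; }

static int pick_tile_m(int m, int n, int sms, int tile_n = 128) {
  if (ceil_div(n, tile_n) * ceil_div(m, 64) <= sms) return 64;
  if (ceil_div(n, tile_n) * ceil_div(m, 128) <= sms) return 128;
  return 256;  // large problems use the 256-row tile with a 2-SM cluster
}

int main() {
  const int sms = 148;  // illustrative SM count only
  std::printf("m=512,  n=4096 -> tile_m=%d\n", pick_tile_m(512, 4096, sms));
  std::printf("m=8192, n=8192 -> tile_m=%d\n", pick_tile_m(8192, 8192, sms));
  return 0;
}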