|
| 1 | +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +// |
| 3 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +// you may not use this file except in compliance with the License. |
| 5 | +// You may obtain a copy of the License at |
| 6 | +// |
| 7 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +// |
| 9 | +// Unless required by applicable law or agreed to in writing, software |
| 10 | +// distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +// See the License for the specific language governing permissions and |
| 13 | +// limitations under the License. |
| 14 | + |
| 15 | +#include "cute/algorithm/copy.hpp" |
| 16 | +#include "cute/atom/mma_atom.hpp" |
| 17 | +#include "cutlass/gemm/collective/collective_builder.hpp" |
| 18 | + |
| 19 | +#include "cutlass/cutlass.h" |
| 20 | +#include "cutlass/layout/layout.h" |
| 21 | +#include "cutlass/numeric_types.h" |
| 22 | +#include "cutlass/pipeline/pipeline.hpp" |
| 23 | + |
| 24 | +using namespace cute; |
| 25 | + |
| 26 | +template <int kStages, class GemmType, class OutputType, class SmemLayoutA, |
| 27 | + class SmemLayoutB, class SmemLayoutC> |
| 28 | +struct SharedStorage { |
| 29 | + union { |
| 30 | + struct { |
| 31 | + cute::array_aligned<GemmType, cute::cosize_v<SmemLayoutA>> smem_a; |
| 32 | + cute::array_aligned<GemmType, cute::cosize_v<SmemLayoutB>> smem_b; |
| 33 | + }; |
| 34 | + cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutC>> smem_c; |
| 35 | + }; |
| 36 | + |
| 37 | + struct { |
| 38 | + typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline; |
| 39 | + }; |
| 40 | +}; |
| 41 | + |
| 42 | +template<int kBlockM_, int kBlockN_, int kBlockK_, |
| 43 | + int kNWarps_, int kStages_, |
| 44 | + int kTiles_, int M_, |
| 45 | + int TokenPackSize_, |
| 46 | + int TAIL_N_ = 0, |
| 47 | + int kClusterM_ = 1, |
| 48 | + typename elem_type=cutlass::float_e4m3_t, |
| 49 | + typename OutputType = cutlass::bfloat16_t> |
| 50 | +struct Kernel_traits { |
| 51 | + using Element = elem_type; |
| 52 | + using ElementAccum = float; |
| 53 | + using ElementOutput = OutputType; |
| 54 | + static_assert(cutlass::sizeof_bits_v<Element> == 8); |
| 55 | + |
| 56 | + static constexpr int kNWarps = kNWarps_; |
| 57 | + static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp; |
| 58 | + static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarpGroup; |
| 59 | + static constexpr int NumMmaThreads = kNThreads - NumProducerThreads; |
| 60 | + |
| 61 | + static_assert(kNWarps_ == 12 || kNWarps_ == 16); |
| 62 | + |
| 63 | + static constexpr int kBlockM = kBlockM_; |
| 64 | + static constexpr int kBlockN = kBlockN_; |
| 65 | + static constexpr int kBlockK = kBlockK_; |
| 66 | + static constexpr int kTiles = kTiles_; |
| 67 | + static constexpr int TokenPackSize = TokenPackSize_; |
| 68 | + static constexpr int M = M_; |
| 69 | + static constexpr int TAIL_N = TAIL_N_; |
| 70 | + |
| 71 | + using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kBlockK>>; |
| 72 | + using TileShape_MNK_TAIL = Shape<Int<kBlockM>, Int<TAIL_N>, Int<kBlockK>>; |
| 73 | + |
| 74 | + static constexpr int kClusterM = kClusterM_; |
| 75 | + using ClusterShape_MNK = Shape<Int<kClusterM>, _1, _1>; |
| 76 | + |
| 77 | + static constexpr int kStages = kStages_; |
| 78 | + static_assert(kStages > 1); |
| 79 | + |
| 80 | + using AtomLayoutMNK = Layout<Shape<Int<kBlockM / 64>, _1, _1>>; |
| 81 | + |
| 82 | + using TiledMma = decltype(cute::make_tiled_mma( |
| 83 | + cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShape_MNK>(), |
| 84 | + AtomLayoutMNK{})); |
| 85 | + |
| 86 | + using TiledMma_TAIL = decltype(cute::make_tiled_mma( |
| 87 | + cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShape_MNK_TAIL>(), |
| 88 | + AtomLayoutMNK{})); |
| 89 | + |
| 90 | + using SmemLayoutAtomA = decltype( |
| 91 | + cutlass::gemm::collective::detail::rs_smem_selector< |
| 92 | + GMMA::Major::K, Element, Int<kBlockM>, Int<kBlockK / 2>>()); |
| 93 | + |
| 94 | + using SmemLayoutA = decltype( |
| 95 | + tile_to_shape(SmemLayoutAtomA{}, |
| 96 | + make_shape(Int<kBlockM>{}, Int<kBlockK / 2>{}, Int<kStages>{}))); |
| 97 | + |
| 98 | + using SmemLayoutAtomB = decltype( |
| 99 | + cutlass::gemm::collective::detail::rs_smem_selector< |
| 100 | + GMMA::Major::K, Element, decltype(cute::get<1>(TileShape_MNK{})), |
| 101 | + decltype(cute::get<2>(TileShape_MNK{}))>()); |
| 102 | + |
| 103 | + using SmemLayoutB = decltype( |
| 104 | + tile_to_shape(SmemLayoutAtomB{}, |
| 105 | + make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{}))); |
| 106 | + |
| 107 | + using SmemLayoutAtomB_TAIL = decltype( |
| 108 | + cutlass::gemm::collective::detail::rs_smem_selector< |
| 109 | + GMMA::Major::K, Element, decltype(cute::get<1>(TileShape_MNK_TAIL{})), |
| 110 | + decltype(cute::get<2>(TileShape_MNK_TAIL{}))>()); |
| 111 | + |
| 112 | + using SmemLayoutB_TAIL = decltype( |
| 113 | + tile_to_shape(SmemLayoutAtomB_TAIL{}, |
| 114 | + make_shape( |
| 115 | + shape<1>(TileShape_MNK_TAIL{}), |
| 116 | + shape<2>(TileShape_MNK_TAIL{}), |
| 117 | + Int<kStages>{}) |
| 118 | + )); |
| 119 | + |
| 120 | + using SmemLayoutAtomC = decltype( |
| 121 | + cutlass::gemm::collective::detail::rs_smem_selector< |
| 122 | + GMMA::Major::K, ElementOutput, |
| 123 | + decltype(cute::get<0>(TileShape_MNK{})), |
| 124 | + decltype(cute::get<1>(TileShape_MNK{}))>()); |
| 125 | + |
| 126 | + using SmemLayoutC = decltype(tile_to_shape(SmemLayoutAtomC{}, select<0, 1>(TileShape_MNK{}))); |
| 127 | + |
| 128 | + using SmemCopyAtomAB = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>; |
| 129 | + using SmemCopyAtomC = Copy_Atom<cute::SM90_U32x4_STSM_N, ElementOutput>; |
| 130 | + |
| 131 | + using SharedStorage = SharedStorage< |
| 132 | + kStages, Element, ElementOutput, SmemLayoutA, SmemLayoutB, SmemLayoutC>; |
| 133 | + |
| 134 | + using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>; |
| 135 | + using PipelineState = typename cutlass::PipelineState<kStages>; |
| 136 | + |
| 137 | + |
| 138 | + static constexpr int kNumVecElem = ceil_div(128, sizeof_bits_v<OutputType>); |
| 139 | + static constexpr int kNumThreadsPerRow = kBlockN / kNumVecElem; |
| 140 | + // static_assert(NumMmaThreads % kNumThreadsPerRow == 0); |
| 141 | + static constexpr int kNumRows = NumMmaThreads / kNumThreadsPerRow; |
| 142 | + using TiledCopyCAtom = cute::Copy_Atom<cute::UniversalCopy<cutlass::uint128_t>, OutputType>; |
| 143 | + using TiledCopyCThrLayout = decltype(cute::make_layout( |
| 144 | + cute::make_shape(Int<kNumRows>{}, Int<kNumThreadsPerRow>{}), |
| 145 | + LayoutRight{})); |
| 146 | + using TiledCopyCValLayout = decltype(cute::make_layout( |
| 147 | + cute::make_shape(_1{}, Int<kNumVecElem>{}), |
| 148 | + LayoutRight{})); |
| 149 | + using TiledCopyC = decltype(make_tiled_copy( |
| 150 | + TiledCopyCAtom{}, |
| 151 | + TiledCopyCThrLayout{}, // Thr layout |
| 152 | + TiledCopyCValLayout{} // Val layout |
| 153 | + )); |
| 154 | +}; |
0 commit comments