diff --git a/examples/00_bmg_gemm/00_bmg_gemm.cpp b/examples/00_bmg_gemm/00_bmg_gemm.cpp index 3918e66451..d7f5178785 100644 --- a/examples/00_bmg_gemm/00_bmg_gemm.cpp +++ b/examples/00_bmg_gemm/00_bmg_gemm.cpp @@ -350,6 +350,10 @@ int main(int argc, const char** argv) // Refer https://github.com/intel/sycl-tla/blob/main/media/docs/cpp/xe_rearchitecture.md using GmemTiledCopyA = void; //XE_LOAD_2D<16, 32, 32>; using GmemTiledCopyB = void; //XE_LOAD_2D_VNNI<16, 32, 32>; + using GmemTiledCopyC = XE_LOAD_2D<32, 8, 16>; + using GmemTiledCopyD = XE_STORE_2D<32, 8, 16>; + + // Workgroup-level tile using TileShape = Shape<_256, _256, _32>; @@ -369,9 +373,8 @@ int main(int argc, const char** argv) // For Intel BMG, PipelineStages defines how many k-blocks ahead to prefetch from A and B. constexpr int PipelineStages = 2; - // For older version of copy/mma atom, use cutlass::gemm::MainloopIntelXeXMX16 as dispatch policy using GEMMDispatchPolicy = cutlass::gemm::MainloopXeL1Staged; - using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeL1Staged; // This is the 'default' epilogue operation (Linear Combination) which performs everything in: // (D = alpha * (A*B) + beta * C) @@ -394,9 +397,9 @@ int main(int argc, const char** argv) ElementOutput, cutlass::gemm::TagToStrideC_t, // Converts CUTLASS 2.x to CUTLASS 3.x representation FusionCallBacks, - XE_2D_U32x8x16_LD_N, // The copy atom used to load matrix C + GmemTiledCopyC, // The copy atom used to load matrix C void, void, - XE_2D_U32x8x16_ST_N, // The copy atom used to store matrix D + GmemTiledCopyD, // The copy atom used to store matrix D void, void>; // GEMM Mainloop - iteration over blocks in K dimension diff --git a/include/cute/atom/copy_traits_xe_2d.hpp b/include/cute/atom/copy_traits_xe_2d.hpp index 60876df697..c180614dd9 100644 --- a/include/cute/atom/copy_traits_xe_2d.hpp +++ b/include/cute/atom/copy_traits_xe_2d.hpp @@ -1143,12 +1143,22 @@ template auto get_block_2d_copy_C(TiledMMA const& tiled_mma, CTensor const& c_tensor) { if constexpr (!std::is_void_v) { - return make_block_2d_copy_C(CopyOp{}, tiled_mma, c_tensor); + return make_block_2d_copy_CD(CopyOp{}, tiled_mma, c_tensor); } else { return make_block_2d_copy_C(tiled_mma, c_tensor); } } +template +auto get_block_2d_copy_D(TiledMMA const& tiled_mma, DTensor const& d_tensor) +{ + if constexpr (!std::is_void_v) { + return make_block_2d_copy_CD(CopyOp{}, tiled_mma, d_tensor); + } else { + return make_block_2d_copy_D(tiled_mma, d_tensor); + } +} + // // Display utilities // diff --git a/include/cutlass/epilogue/collective/collective_epilogue.hpp b/include/cutlass/epilogue/collective/collective_epilogue.hpp index 2e7fbeb579..4f6d041363 100644 --- a/include/cutlass/epilogue/collective/collective_epilogue.hpp +++ b/include/cutlass/epilogue/collective/collective_epilogue.hpp @@ -71,6 +71,7 @@ class CollectiveEpilogue { #include "sm100_epilogue_array_tma_warpspecialized.hpp" #if defined (SYCL_INTEL_TARGET) #include "xe_epilogue.hpp" +#include "xe_epilogue_legacy.hpp" #include "xe_array_epilogue.hpp" #endif // diff --git a/include/cutlass/epilogue/collective/xe_epilogue.hpp b/include/cutlass/epilogue/collective/xe_epilogue.hpp index 05437ab180..0a5abe7b78 100644 --- a/include/cutlass/epilogue/collective/xe_epilogue.hpp +++ b/include/cutlass/epilogue/collective/xe_epilogue.hpp @@ -68,7 +68,7 @@ template < class CopyOpR2S_ > class CollectiveEpilogue< - IntelXeXMX16, + IntelXeL1Staged, CtaTileMNK_, ElementC_, StrideC_, @@ -86,7 
+86,7 @@ class CollectiveEpilogue< // // Type Aliases // - using DispatchPolicy = IntelXeXMX16; + using DispatchPolicy = IntelXeL1Staged; using CtaTileMNK = CtaTileMNK_; using FusionCallbacks = FusionCallbacks_; using ElementC = ElementC_; @@ -102,9 +102,6 @@ class CollectiveEpilogue< using CopyOpR2S = CopyOpR2S_; using ThreadEpilogueOp = typename fusion::FusionCallbacksTraits::Operation; - using GmemTiledCopyC = conditional_t, XE_2D_U32x8x16_LD_N, CopyOpG2R>; - using GmemTiledCopyD = cute::conditional_t && not cute::is_void_v, - CopyOpR2G, XE_2D_U32x8x16_ST_N>; using ElementOutput = ElementD; using ElementCompute = ElementAccumulator; @@ -119,19 +116,10 @@ class CollectiveEpilogue< static_assert(std::is_same_v, "Copy operation to shared memory is not supported"); static_assert(std::is_same_v, "Copy operation to shared memory is not supported"); - using CopyThreadShape = Shape<_1, Int>; - - using Trait_C = Copy_Traits; - using val_layout_load_C = decltype(make_layout(shape_div(typename Trait_C::BlockShape{}, CopyThreadShape{}))); - using XE_Copy_C = decltype(make_tiled_copy(Copy_Atom{}, Layout{}, val_layout_load_C{})); - - using Trait_D = Copy_Traits; - using val_layout_store_D = decltype(make_layout(shape_div(typename Trait_D::BlockShape{}, CopyThreadShape{}))); - using XE_Copy_D = decltype(make_tiled_copy(Copy_Atom{}, Layout{}, val_layout_store_D{})); - private: - constexpr static bool is_source_supported = not cute::is_void_v && not cute::is_void_v; - constexpr static bool is_destination_supported = not cute::is_void_v && not cute::is_void_v; + constexpr static bool is_source_supported = not cute::is_void_v; + constexpr static bool is_destination_supported = not cute::is_void_v; constexpr static bool is_m_major_C = detail::is_m_major(); constexpr static bool is_m_major_D = detail::is_m_major(); @@ -154,6 +142,15 @@ class CollectiveEpilogue< }; using TensorStorage = typename SharedStorage::TensorStorage; + // Helper to get tensor types + template + using TensorTypeC = decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), + make_layout(make_shape(int{}, int{}, int{}), Stride{}))); + + template + using TensorTypeD = decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), + make_layout(make_shape(int{}, int{}, int{}), Stride{}))); + + // Host side epilogue arguments struct Arguments { typename FusionCallbacks::Arguments thread{}; @@ -166,8 +163,8 @@ class CollectiveEpilogue< // Device side epilogue params struct Params { typename FusionCallbacks::Params thread{}; - XE_Copy_C xe_load_c; - XE_Copy_D xe_store_d; + TensorTypeC mC; + TensorTypeD mD; }; // @@ -183,23 +180,13 @@ class CollectiveEpilogue< // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) auto problem_shape_MNKL = append<4>(problem_shape, 1); auto [M, N, K, L] = problem_shape_MNKL; - - XE_Copy_C xe_load_c = {}; - if constexpr (is_source_supported) { - auto mC = make_tensor(make_gmem_ptr(args.ptr_C), make_layout(make_shape(M, N, L), args.dC)); - xe_load_c = {xe_load_c.with(mC)}; - } - - XE_Copy_D xe_store_d = {}; - if constexpr (is_destination_supported) { - auto mD = make_tensor(make_gmem_ptr(args.ptr_D), make_layout(make_shape(M, N, L), args.dD)); - xe_store_d = {xe_store_d.with(mD)}; - } + auto mC = make_tensor(make_gmem_ptr(args.ptr_C), make_layout(make_shape(M, N, L), args.dC)); + auto mD = make_tensor(make_gmem_ptr(args.ptr_D), make_layout(make_shape(M, N, L), args.dD)); return { 
FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, workspace), - xe_load_c, - xe_store_d, + mC, + mD }; } @@ -270,6 +257,37 @@ class CollectiveEpilogue< return fusion_callbacks.is_producer_load_needed(); } + template + CUTLASS_DEVICE auto reshape_with_unit_insertion(Tensor&& tensor) { + using namespace cute; + + auto orig_layout = tensor.layout(); + auto orig_shape = orig_layout.shape(); + auto orig_stride = orig_layout.stride(); + + auto first_dim = get<0>(orig_shape); + auto outer_part = get<0>(first_dim); + auto inner_part = get<1>(first_dim); + + auto first_stride = get<0>(orig_stride); + auto outer_stride = get<0>(first_stride); + auto inner_stride = get<1>(first_stride); + + auto target_shape = make_shape( + make_shape(outer_part, _1{}), + get<0>(inner_part), + get<1>(inner_part) + ); + + auto target_stride = make_stride( + make_stride(outer_stride, _0{}), + get<0>(inner_stride), + get<1>(inner_stride) + ); + + return make_tensor(tensor.data(), make_layout(target_shape, target_stride)); +} + template< class ProblemShapeMNKL, class TileShapeMNK, @@ -286,7 +304,6 @@ class CollectiveEpilogue< TiledMma tiled_mma, int thread_idx) { - (void) tiled_mma; using namespace cute; static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]"); @@ -297,12 +314,11 @@ class CollectiveEpilogue< static constexpr auto BLK_M = get<0>(CtaTileMNK{}); static constexpr auto BLK_N = get<1>(CtaTileMNK{}); static constexpr auto BLK_K = get<2>(CtaTileMNK{}); - // static_assert(is_same_v, "assertation fail"); static constexpr auto ATOM_M = get<1>(typename TiledMma::ThrLayoutVMNK{}.shape()); static constexpr auto ATOM_N = get<2>(typename TiledMma::ThrLayoutVMNK{}.shape()); static constexpr auto ATOM_K = get<3>(typename TiledMma::ThrLayoutVMNK{}.shape()); - - static_assert( + + static_assert( BLK_M % ATOM_M == 0 && BLK_N % ATOM_N == 0 && BLK_K % ATOM_K == 0, @@ -316,46 +332,46 @@ class CollectiveEpilogue< static constexpr int FragsN = get<1>(SubgroupTileShape{}) / get<1>(MmaAtomShape()); // B frags per sub_group static constexpr int FragmentSize = (get<0>(MmaAtomShape()) * get<1>(MmaAtomShape())) / SubgroupSize; - + // Indexing variables auto [M, N, K, L] = problem_shape_mnkl; auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; auto m_sg = get_sub_group_id() / ATOM_N; auto n_sg = get_sub_group_id() % ATOM_N; - auto mn_shape = shape(typename decltype(params.xe_store_d)::Tiler_MN{}); - auto sg_local_m_coord = get_sub_group_id() / ATOM_N; auto sg_local_n_coord = get_sub_group_id() % ATOM_N; auto sg_m_coord = m_coord * ATOM_M + sg_local_m_coord; auto sg_n_coord = n_coord * ATOM_N + sg_local_n_coord; auto sg_coord = make_coord(sg_m_coord, sg_n_coord, k_coord, l_coord); - + + auto wg_coord = make_coord(m_coord, n_coord, k_coord, l_coord); bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); + auto batch_idx = get<3>(wg_coord); + auto copy_c = get_block_2d_copy_C(tiled_mma, params.mC(_,_,batch_idx)); + auto copy_d = get_block_2d_copy_D(tiled_mma, params.mD(_,_,batch_idx)); + + + // Represent the full output tensor Tensor mD_mnl = cute::get_xe_tensor(make_shape(M,N,L)); - // Tile the output tensor per WG and select the tile for current WG - Tensor g_wg_D = local_tile(mD_mnl, take<0,2>(CtaTileMNK{}), make_coord(m_coord,n_coord,l_coord)); // (BLK_M,BLK_N) - - // Tile the output tensor per SG and select tile for the current SG - Tensor gD = local_tile(g_wg_D, take<0,2>(SubgroupTileShape{}), make_coord(m_sg,n_sg)); // (SG_M,SG_N) + // 
Tile the output tensor for the current workgroup + Tensor gD = local_tile(mD_mnl, take<0,2>(CtaTileMNK{}), remove<2>(wg_coord)); // (BLK_M,BLK_N) - auto thread_xe_load_c = params.xe_load_c.get_thread_slice(thread_idx); - Tensor tCgC = thread_xe_load_c.partition_S(gD); + // Get thread-level partitioning across the entire workgroup tile + auto thread_xe_load_c = copy_c.get_thread_slice(thread_idx); + Tensor tCgC = reshape_with_unit_insertion(thread_xe_load_c.partition_S(gD)); - auto thread_xe_store_d = params.xe_store_d.get_thread_slice(thread_idx); - Tensor tCgD = thread_xe_store_d.partition_D(gD); + auto thread_xe_store_d = copy_d.get_thread_slice(thread_idx); + Tensor tCgD = reshape_with_unit_insertion(thread_xe_store_d.partition_D(gD)); Tensor trC = make_tensor(Shape>{}); Tensor trD_compute = make_tensor(Shape>{}); - - // Because Sm90 uses shared memory, they are not tied to using the same accumulator values - // for MMA and Epilogue. But because we are operating directly in the accumulators, we need to be - // sure that we are operating on the same values. - ThrCopy thread_g2r = params.xe_load_c.get_slice(thread_idx); + ThrCopy thread_g2r = copy_c.get_slice(thread_idx); + auto mn_shape = shape(typename decltype(copy_d)::Tiler_MN{}); // OOB predication for tile quantization "residue" // Absolute coordinate tensors (dynamic) @@ -364,7 +380,7 @@ class CollectiveEpilogue< Tensor cD_mn = local_tile(mD_crd, take<0,2>(CtaTileMNK{}), make_coord(m_coord, n_coord)); // (CTA_M,CTA_N) Tensor tRS_cD_mn = thread_g2r.partition_S(flat_divide(cD_mn, mn_shape)); // (G2R,G2R_M,G2R_N,EPI_M,EPI_N) - Tensor tRS_cD = make_coord_tensor(tRS_cD_mn.layout()); // (G2R,G2R_M,G2R_N,EPI_M,EPI_N) + Tensor tRS_cD = make_coord_tensor(tRS_cD_mn.layout()); // Get the fusion callbacks // Arguments passed here relate to sub-group tiles, rather than CTA (work-group) tiles @@ -376,7 +392,7 @@ class CollectiveEpilogue< sg_coord, tiled_mma, mn_shape, - params.xe_store_d, + copy_d, cD, residue_mn, tRS_cD, @@ -398,7 +414,8 @@ class CollectiveEpilogue< FragsM * FragsN * FragmentSize * SubgroupSize * ATOM_M * ATOM_N * ATOM_K; constexpr int MN = get<0>(CtaTileMNK{}) * get<1>(CtaTileMNK{}); static_assert(ValuesLoaded == MN, "the total elements loaded by all threads should be the same as MxN" ); - + + auto synchronize = [&] () {}; CUTLASS_PRAGMA_UNROLL for (int epi_n = 0; epi_n < FragsN; epi_n++) { @@ -407,7 +424,7 @@ class CollectiveEpilogue< cst_callbacks.begin_loop(epi_m, epi_n); if (is_C_load_needed) { - copy(params.xe_load_c, tCgC(_, epi_m, epi_n), trC); + copy(copy_c, tCgC(_, epi_m, epi_n), trC); } cst_callbacks.previsit(epi_m, epi_n, 0, is_C_load_needed); @@ -419,21 +436,23 @@ class CollectiveEpilogue< trD_compute_frag(epi_v) = cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n); } cst_callbacks.reduce(nullptr, synchronize, epi_m, epi_n, (epi_m == FragsM - 1 && epi_n == FragsN - 1), trD_compute_frag); - + if constexpr (is_destination_supported) { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(trD_compute_frag); ++i) { trD_frag(i) = cutlass::NumericArrayConverter{}(trD_compute_frag(i)); } - copy(params.xe_store_d, trD, tCgD(_, epi_m, epi_n)); + copy(copy_d, trD, tCgD(_, epi_m, epi_n)); } cst_callbacks.end_loop(epi_m, epi_n); + } } cst_callbacks.end(); - } + +} private: Params const& params; @@ -447,4 +466,4 @@ class CollectiveEpilogue< } // namespace epilogue } // namespace cutlass -///////////////////////////////////////////////////////////////////////////////////////////////// 
+///////////////////////////////////////////////////////////////////////////////////////////////// \ No newline at end of file diff --git a/include/cutlass/epilogue/collective/xe_epilogue_legacy.hpp b/include/cutlass/epilogue/collective/xe_epilogue_legacy.hpp new file mode 100644 index 0000000000..05437ab180 --- /dev/null +++ b/include/cutlass/epilogue/collective/xe_epilogue_legacy.hpp @@ -0,0 +1,450 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing elementwise operations used by epilogues. 
+*/ + +#pragma once + +#include +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_epilogue.hpp" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/epilogue/fusion/callbacks.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" +#include "cutlass/detail/layout.hpp" + +#include "cute/tensor.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class CtaTileMNK_, + class ElementC_, + class StrideC_, + class ElementD_, + class StrideD_, + class FusionCallbacks_, + class CopyOpG2R_, + class SmemLayoutAtomC_, + class CopyOpS2R_, + class CopyOpR2G_, + class SmemLayoutAtomD_, + class CopyOpR2S_ +> +class CollectiveEpilogue< + IntelXeXMX16, + CtaTileMNK_, + ElementC_, + StrideC_, + ElementD_, + StrideD_, + FusionCallbacks_, + CopyOpG2R_, + SmemLayoutAtomC_, + CopyOpS2R_, + CopyOpR2G_, + SmemLayoutAtomD_, + CopyOpR2S_ +> { +public: + // + // Type Aliases + // + using DispatchPolicy = IntelXeXMX16; + using CtaTileMNK = CtaTileMNK_; + using FusionCallbacks = FusionCallbacks_; + using ElementC = ElementC_; + using ElementAccumulator = ElementC_; + using StrideC = StrideC_; + using ElementD = ElementD_; + using StrideD = StrideD_; + using CopyOpG2R = CopyOpG2R_; + using SmemLayoutAtomC = SmemLayoutAtomC_; + using CopyOpS2R = CopyOpS2R_; + using CopyOpR2G = CopyOpR2G_; + using SmemLayoutAtomD = SmemLayoutAtomD_; + using CopyOpR2S = CopyOpR2S_; + + using ThreadEpilogueOp = typename fusion::FusionCallbacksTraits::Operation; + using GmemTiledCopyC = conditional_t, XE_2D_U32x8x16_LD_N, CopyOpG2R>; + using GmemTiledCopyD = cute::conditional_t && not cute::is_void_v, + CopyOpR2G, XE_2D_U32x8x16_ST_N>; + using ElementOutput = ElementD; + using ElementCompute = ElementAccumulator; + + static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; + + static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]"); + static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]"); + static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]"); + + static_assert(std::is_same_v, "Copy operation to shared memory is not supported"); + static_assert(std::is_same_v, "Copy operation to shared memory is not supported"); + static_assert(std::is_same_v, "Copy operation to shared memory is not supported"); + static_assert(std::is_same_v, "Copy operation to shared memory is not supported"); + + using CopyThreadShape = Shape<_1, Int>; + + using Trait_C = Copy_Traits; + using val_layout_load_C = decltype(make_layout(shape_div(typename Trait_C::BlockShape{}, CopyThreadShape{}))); + using XE_Copy_C = decltype(make_tiled_copy(Copy_Atom{}, Layout{}, val_layout_load_C{})); + + using Trait_D = Copy_Traits; + using val_layout_store_D = decltype(make_layout(shape_div(typename Trait_D::BlockShape{}, CopyThreadShape{}))); + using XE_Copy_D = decltype(make_tiled_copy(Copy_Atom{}, Layout{}, val_layout_store_D{})); + +private: + constexpr static bool is_source_supported = not cute::is_void_v && not cute::is_void_v; + constexpr static bool is_destination_supported = not cute::is_void_v && not cute::is_void_v; + + constexpr static bool is_m_major_C = detail::is_m_major(); + constexpr static bool is_m_major_D = 
detail::is_m_major(); + +public: + + using EmptyType = cute::tuple<>; + using SmemCStorage = EmptyType; + using SmemDStorage = EmptyType; + + struct TensorStorageImpl: cute::tuple { + using FusionStorage = typename FusionCallbacks::SharedStorage; + FusionStorage thread; + }; + + struct SharedStorage { + using TensorStorage = TensorStorageImpl; + + TensorStorage tensors; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + + // Host side epilogue arguments + struct Arguments { + typename FusionCallbacks::Arguments thread{}; + ElementC const* ptr_C; + StrideC dC; + ElementD* ptr_D; + StrideD dD; + }; + + // Device side epilogue params + struct Params { + typename FusionCallbacks::Params thread{}; + XE_Copy_C xe_load_c; + XE_Copy_D xe_store_d; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace) { + // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + XE_Copy_C xe_load_c = {}; + if constexpr (is_source_supported) { + auto mC = make_tensor(make_gmem_ptr(args.ptr_C), make_layout(make_shape(M, N, L), args.dC)); + xe_load_c = {xe_load_c.with(mC)}; + } + + XE_Copy_D xe_store_d = {}; + if constexpr (is_destination_supported) { + auto mD = make_tensor(make_gmem_ptr(args.ptr_D), make_layout(make_shape(M, N, L), args.dD)); + xe_store_d = {xe_store_d.with(mD)}; + } + + return { + FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, workspace), + xe_load_c, + xe_store_d, + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return Status::kSuccess; + } + + template + CUTLASS_HOST_DEVICE static bool + can_implement( + ProblemShape const& problem_shapes, + Arguments const& args) { + constexpr int copy_alignment_bits = 128; + constexpr int batch_alignment_bits = 512; + auto problem_shape_MNKL = append<4>(problem_shapes, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + bool implementable = true; + bool fusion_implementable = true; + + if constexpr (is_destination_supported) { + constexpr int min_aligned_elements_D = copy_alignment_bits / sizeof_bits::value; + implementable &= cutlass::detail::check_alignment(cute::make_shape(M,N,L), args.dD); + if (L > 1) { + constexpr int min_batch_aligned_elements_D = batch_alignment_bits / sizeof_bits::value; + implementable &= get<2>(args.dD) % min_batch_aligned_elements_D == 0; + } + } + + if constexpr (is_source_supported) { + constexpr int min_aligned_elements_C = copy_alignment_bits / sizeof_bits::value; + implementable &= cutlass::detail::check_alignment(cute::make_shape(M,N,L), args.dC); + if (L > 1) { + constexpr int min_batch_aligned_elements_C = batch_alignment_bits / sizeof_bits::value; + implementable &= get<2>(args.dC) % min_batch_aligned_elements_C == 0; + } + } + + fusion_implementable = fusion_implementable && FusionCallbacks::can_implement(problem_shape_MNKL, args.thread); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for XE 2D copy.\n"); + } + + if (!fusion_implementable) { + 
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements for FusionCallbacks.\n"); + } + + return implementable && fusion_implementable; + } + + CUTLASS_HOST_DEVICE + CollectiveEpilogue(Params const& params_, TensorStorage const& shared_storage_) + : params(params_), fusion_callbacks(params_.thread, shared_storage_.thread) {} + + CUTLASS_DEVICE + bool + is_producer_load_needed() const { + return fusion_callbacks.is_producer_load_needed(); + } + + template< + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class Accumulator, + class TiledMma + > + CUTLASS_DEVICE void + operator() ( + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_MNK, + TileCoordMNKL tile_coord_mnkl, + Accumulator accumulators, + TiledMma tiled_mma, + int thread_idx) { + + (void) tiled_mma; + using namespace cute; + + static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]"); + static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]"); + static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]"); + + using MmaAtomShape = typename TiledMma::AtomShape_MNK; + static constexpr auto BLK_M = get<0>(CtaTileMNK{}); + static constexpr auto BLK_N = get<1>(CtaTileMNK{}); + static constexpr auto BLK_K = get<2>(CtaTileMNK{}); + // static_assert(is_same_v, "assertation fail"); + static constexpr auto ATOM_M = get<1>(typename TiledMma::ThrLayoutVMNK{}.shape()); + static constexpr auto ATOM_N = get<2>(typename TiledMma::ThrLayoutVMNK{}.shape()); + static constexpr auto ATOM_K = get<3>(typename TiledMma::ThrLayoutVMNK{}.shape()); + + static_assert( + BLK_M % ATOM_M == 0 && + BLK_N % ATOM_N == 0 && + BLK_K % ATOM_K == 0, + "expected CTATileMNK to be evenly divided by TiledMma::ThrLayoutVMNK"); + static constexpr auto SG_M = BLK_M / ATOM_M; + static constexpr auto SG_N = BLK_N / ATOM_N; + static constexpr auto SG_K = BLK_K / ATOM_K; + using SubgroupTileShape = Shape; + + static constexpr int FragsM = get<0>(SubgroupTileShape{}) / get<0>(MmaAtomShape()); // A frags per sub_group + static constexpr int FragsN = get<1>(SubgroupTileShape{}) / get<1>(MmaAtomShape()); // B frags per sub_group + + static constexpr int FragmentSize = (get<0>(MmaAtomShape()) * get<1>(MmaAtomShape())) / SubgroupSize; + + // Indexing variables + auto [M, N, K, L] = problem_shape_mnkl; + auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; + auto m_sg = get_sub_group_id() / ATOM_N; + auto n_sg = get_sub_group_id() % ATOM_N; + + auto mn_shape = shape(typename decltype(params.xe_store_d)::Tiler_MN{}); + + auto sg_local_m_coord = get_sub_group_id() / ATOM_N; + auto sg_local_n_coord = get_sub_group_id() % ATOM_N; + + auto sg_m_coord = m_coord * ATOM_M + sg_local_m_coord; + auto sg_n_coord = n_coord * ATOM_N + sg_local_n_coord; + auto sg_coord = make_coord(sg_m_coord, sg_n_coord, k_coord, l_coord); + + bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); + + // Represent the full output tensor + Tensor mD_mnl = cute::get_xe_tensor(make_shape(M,N,L)); + + // Tile the output tensor per WG and select the tile for current WG + Tensor g_wg_D = local_tile(mD_mnl, take<0,2>(CtaTileMNK{}), make_coord(m_coord,n_coord,l_coord)); // (BLK_M,BLK_N) + + // Tile the output tensor per SG and select tile for the current SG + Tensor gD = local_tile(g_wg_D, take<0,2>(SubgroupTileShape{}), make_coord(m_sg,n_sg)); // (SG_M,SG_N) + + auto thread_xe_load_c = params.xe_load_c.get_thread_slice(thread_idx); + Tensor 
tCgC = thread_xe_load_c.partition_S(gD); + + auto thread_xe_store_d = params.xe_store_d.get_thread_slice(thread_idx); + Tensor tCgD = thread_xe_store_d.partition_D(gD); + + Tensor trC = make_tensor(Shape>{}); + Tensor trD_compute = make_tensor(Shape>{}); + + // Because Sm90 uses shared memory, they are not tied to using the same accumulator values + // for MMA and Epilogue. But because we are operating directly in the accumulators, we need to be + // sure that we are operating on the same values. + ThrCopy thread_g2r = params.xe_load_c.get_slice(thread_idx); + + // OOB predication for tile quantization "residue" + // Absolute coordinate tensors (dynamic) + Tensor mD_crd = make_identity_tensor(make_shape(M,N)); // (M,N) + Tensor cD = local_tile(mD_crd, take<0,2>(SubgroupTileShape{}), make_coord(sg_m_coord, sg_n_coord)); + Tensor cD_mn = local_tile(mD_crd, take<0,2>(CtaTileMNK{}), make_coord(m_coord, n_coord)); // (CTA_M,CTA_N) + Tensor tRS_cD_mn = thread_g2r.partition_S(flat_divide(cD_mn, mn_shape)); // (G2R,G2R_M,G2R_N,EPI_M,EPI_N) + + Tensor tRS_cD = make_coord_tensor(tRS_cD_mn.layout()); // (G2R,G2R_M,G2R_N,EPI_M,EPI_N) + + // Get the fusion callbacks + // Arguments passed here relate to sub-group tiles, rather than CTA (work-group) tiles + constexpr bool RefSrc = true; + auto residue_mn = make_coord(M, N); //TODO(Codeplay): this is not correct + auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{ + problem_shape_mnkl, + SubgroupTileShape{}, + sg_coord, + tiled_mma, + mn_shape, + params.xe_store_d, + cD, + residue_mn, + tRS_cD, + residue_mn, + trC, + thread_idx, + }; + auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); + + cst_callbacks.begin(); + + auto acc_frag = recast>(accumulators); + auto trD_compute_frag = recast>(trD_compute); + + Tensor trD = make_tensor(Shape>{}); + auto trD_frag = recast>(trD); + + constexpr int ValuesLoaded = + FragsM * FragsN * FragmentSize * SubgroupSize * ATOM_M * ATOM_N * ATOM_K; + constexpr int MN = get<0>(CtaTileMNK{}) * get<1>(CtaTileMNK{}); + static_assert(ValuesLoaded == MN, "the total elements loaded by all threads should be the same as MxN" ); + + auto synchronize = [&] () {}; + CUTLASS_PRAGMA_UNROLL + for (int epi_n = 0; epi_n < FragsN; epi_n++) { + CUTLASS_PRAGMA_UNROLL + for (int epi_m = 0; epi_m < FragsM; epi_m++) { + cst_callbacks.begin_loop(epi_m, epi_n); + + if (is_C_load_needed) { + copy(params.xe_load_c, tCgC(_, epi_m, epi_n), trC); + } + + cst_callbacks.previsit(epi_m, epi_n, 0, is_C_load_needed); + + auto acc_frag_mn = acc_frag(_, epi_m, epi_n); + + CUTLASS_PRAGMA_UNROLL + for (int epi_v = 0; epi_v < size<0>(trD_compute_frag); ++epi_v) { + trD_compute_frag(epi_v) = cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n); + } + cst_callbacks.reduce(nullptr, synchronize, epi_m, epi_n, (epi_m == FragsM - 1 && epi_n == FragsN - 1), trD_compute_frag); + + if constexpr (is_destination_supported) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(trD_compute_frag); ++i) { + trD_frag(i) = cutlass::NumericArrayConverter{}(trD_compute_frag(i)); + } + copy(params.xe_store_d, trD, tCgD(_, epi_m, epi_n)); + } + + cst_callbacks.end_loop(epi_m, epi_n); + } + } + + cst_callbacks.end(); + } + +private: + Params const& params; + FusionCallbacks fusion_callbacks; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace epilogue +} // namespace cutlass + 
+///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/dispatch_policy.hpp b/include/cutlass/epilogue/dispatch_policy.hpp index fbb4a40b52..b661d51c10 100644 --- a/include/cutlass/epilogue/dispatch_policy.hpp +++ b/include/cutlass/epilogue/dispatch_policy.hpp @@ -304,6 +304,12 @@ struct IntelXeXMX16 { struct IntelXeXMX16Group { static constexpr int SubgroupSize = 16; }; + +// Note: This dispatch policy is specifically added for CollectiveEpilogue to support +// the integration of new MMA atoms (XE_DPAS_TT) and copy atoms for Intel XE architecture +struct IntelXeL1Staged { + static constexpr int SubgroupSize = 16; +}; #endif ////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/fusion/xe_callbacks.hpp b/include/cutlass/epilogue/fusion/xe_callbacks.hpp index 5173d77000..71220d281d 100644 --- a/include/cutlass/epilogue/fusion/xe_callbacks.hpp +++ b/include/cutlass/epilogue/fusion/xe_callbacks.hpp @@ -109,6 +109,59 @@ struct FusionCallbacks< using Impl::Impl; }; +template < + class ElementOutput_, + class ElementCompute_, + class ElementSource_, + class ElementScalar_, + FloatRoundStyle RoundStyle_, + class CtaTileShapeMNK_, + class EpilogueTile_ +> +struct FusionCallbacks< + epilogue::IntelXeL1Staged, + fusion::LinearCombination, + CtaTileShapeMNK_, + EpilogueTile_ +> : Sm90LinearCombination::type, ElementCompute_, ElementSource_, ElementScalar_, RoundStyle_> { + + using Impl = Sm90LinearCombination::type, ElementCompute_, ElementSource_, ElementScalar_, RoundStyle_>; + using ElementOutput = ElementOutput_; + using ElementCompute = ElementCompute_; + using ElementSource = ElementSource_; + using ElementScalar = ElementScalar_; + using Operation = fusion::LinearCombination; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + operator typename Impl::Arguments() const { + return + { // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + template < template class ActivationFn_,
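// ---------------------------------------------------------------------------
// Minimal usage sketch of how the pieces introduced by this patch fit together
// on the example side, mirroring examples/00_bmg_gemm/00_bmg_gemm.cpp. Names
// that do not appear in this diff (ElementOutput, ElementCompute, ElementSource,
// ElementScalar, ElementAccumulator, LayoutC, LayoutD, EpilogueTile, ptr_C,
// ptr_D, stride_C, stride_D, and `using namespace cute;`) are assumptions
// borrowed from the example's surrounding context, not part of the patch.

using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeL1Staged;
using TileShape              = Shape<_256, _256, _32>;

// D = alpha * (A*B) + beta * C, routed through the new
// FusionCallbacks<IntelXeL1Staged, LinearCombination<...>> specialization above.
using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<
    ElementOutput, ElementCompute, ElementSource, ElementScalar,
    cutlass::FloatRoundStyle::round_to_nearest>;
using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<
    EpilogueDispatchPolicy, EpilogueOp, TileShape, EpilogueTile>;

// New-style 2D block copy atoms, forwarded as CopyOpG2R / CopyOpR2G and turned
// into tiled copies inside operator() via get_block_2d_copy_C / get_block_2d_copy_D.
using GmemTiledCopyC = XE_LOAD_2D<32, 8, 16>;   // copy atom used to load matrix C
using GmemTiledCopyD = XE_STORE_2D<32, 8, 16>;  // copy atom used to store matrix D

using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue<
    EpilogueDispatchPolicy,
    TileShape,
    ElementAccumulator, cutlass::gemm::TagToStrideC_t<LayoutC>,
    ElementOutput,      cutlass::gemm::TagToStrideC_t<LayoutD>,
    FusionCallBacks,
    GmemTiledCopyC, void, void,   // CopyOpG2R, SmemLayoutAtomC, CopyOpS2R
    GmemTiledCopyD, void, void>;  // CopyOpR2G, SmemLayoutAtomD, CopyOpR2S

// Host-side arguments: alpha/beta travel in FusionCallbacks::Arguments
// (alpha_ptr/beta_ptr and dAlpha/dBeta keep their defaults), while the
// collective's Arguments carries the C/D pointers and strides; Params now
// stores the mC/mD tensors instead of prebuilt XE_Copy_C / XE_Copy_D objects.
typename CollectiveEpilogue::Arguments epilogue_args{
    {/*alpha=*/1.0f, /*beta=*/0.0f},
    ptr_C, stride_C,
    ptr_D, stride_D};
// ---------------------------------------------------------------------------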