From 578ae95e97db23111318545d6b9065f63a3d48fc Mon Sep 17 00:00:00 2001 From: Anamika Chatterjee Date: Wed, 22 Oct 2025 09:22:38 +0300 Subject: [PATCH 1/7] Add new atoms in collectiveEpilogue --- examples/00_bmg_gemm/00_bmg_gemm.cpp | 8 +- .../collective/collective_epilogue.hpp | 1 + .../epilogue/collective/xe_epilogue.hpp | 241 +++++----- .../collective/xe_epilogue_legacy.cpp | 450 ++++++++++++++++++ include/cutlass/epilogue/dispatch_policy.hpp | 6 + .../cutlass/epilogue/fusion/xe_callbacks.hpp | 53 +++ 6 files changed, 629 insertions(+), 130 deletions(-) create mode 100644 include/cutlass/epilogue/collective/xe_epilogue_legacy.cpp diff --git a/examples/00_bmg_gemm/00_bmg_gemm.cpp b/examples/00_bmg_gemm/00_bmg_gemm.cpp index 7e9291227e..90a15e00ad 100644 --- a/examples/00_bmg_gemm/00_bmg_gemm.cpp +++ b/examples/00_bmg_gemm/00_bmg_gemm.cpp @@ -348,6 +348,8 @@ int main(int argc, const char** argv) // The 2D block copy operations used for the A and B matrices using GmemTiledCopyA = XE_2D_U16x32x32_LD_N; using GmemTiledCopyB = XE_2D_U16x32x32_LD_V; + using GmemTiledCopyC = void; //XE_LOAD_2D<32, 8, 16>; + using GmemTiledCopyD = void; //XE_STORE_2D<32, 8, 16>; // Workgroup-level tile using TileShape = Shape<_256, _256, _32>; @@ -369,7 +371,7 @@ int main(int argc, const char** argv) // For Intel BMG, PipelineStages defines how many k-blocks ahead to prefetch from A and B. constexpr int PipelineStages = 2; using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16; - using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeL1Staged; // This is the 'default' epilogue operation (Linear Combination) which performs everything in: // (D = alpha * (A*B) + beta * C) @@ -392,9 +394,9 @@ int main(int argc, const char** argv) ElementOutput, cutlass::gemm::TagToStrideC_t, // Converts CUTLASS 2.x to CUTLASS 3.x representation FusionCallBacks, - XE_2D_U32x8x16_LD_N, // The copy atom used to load matrix C + GmemTiledCopyC, // The copy atom used to load matrix C void, void, - XE_2D_U32x8x16_ST_N, // The copy atom used to store matrix D + GmemTiledCopyD, // The copy atom used to store matrix D void, void>; // GEMM Mainloop - iteration over blocks in K dimension diff --git a/include/cutlass/epilogue/collective/collective_epilogue.hpp b/include/cutlass/epilogue/collective/collective_epilogue.hpp index 2e7fbeb579..4f6d041363 100644 --- a/include/cutlass/epilogue/collective/collective_epilogue.hpp +++ b/include/cutlass/epilogue/collective/collective_epilogue.hpp @@ -71,6 +71,7 @@ class CollectiveEpilogue { #include "sm100_epilogue_array_tma_warpspecialized.hpp" #if defined (SYCL_INTEL_TARGET) #include "xe_epilogue.hpp" +#include "xe_epilogue_legacy.hpp" #include "xe_array_epilogue.hpp" #endif // diff --git a/include/cutlass/epilogue/collective/xe_epilogue.hpp b/include/cutlass/epilogue/collective/xe_epilogue.hpp index 05437ab180..6c5720528d 100644 --- a/include/cutlass/epilogue/collective/xe_epilogue.hpp +++ b/include/cutlass/epilogue/collective/xe_epilogue.hpp @@ -68,7 +68,7 @@ template < class CopyOpR2S_ > class CollectiveEpilogue< - IntelXeXMX16, + IntelXeL1Staged, CtaTileMNK_, ElementC_, StrideC_, @@ -86,7 +86,7 @@ class CollectiveEpilogue< // // Type Aliases // - using DispatchPolicy = IntelXeXMX16; + using DispatchPolicy = IntelXeL1Staged; using CtaTileMNK = CtaTileMNK_; using FusionCallbacks = FusionCallbacks_; using ElementC = ElementC_; @@ -102,9 +102,6 @@ class CollectiveEpilogue< using CopyOpR2S = CopyOpR2S_; using ThreadEpilogueOp = 
typename fusion::FusionCallbacksTraits::Operation; - using GmemTiledCopyC = conditional_t, XE_2D_U32x8x16_LD_N, CopyOpG2R>; - using GmemTiledCopyD = cute::conditional_t && not cute::is_void_v, - CopyOpR2G, XE_2D_U32x8x16_ST_N>; using ElementOutput = ElementD; using ElementCompute = ElementAccumulator; @@ -119,19 +116,10 @@ class CollectiveEpilogue< static_assert(std::is_same_v, "Copy operation to shared memory is not supported"); static_assert(std::is_same_v, "Copy operation to shared memory is not supported"); - using CopyThreadShape = Shape<_1, Int>; - - using Trait_C = Copy_Traits; - using val_layout_load_C = decltype(make_layout(shape_div(typename Trait_C::BlockShape{}, CopyThreadShape{}))); - using XE_Copy_C = decltype(make_tiled_copy(Copy_Atom{}, Layout{}, val_layout_load_C{})); - - using Trait_D = Copy_Traits; - using val_layout_store_D = decltype(make_layout(shape_div(typename Trait_D::BlockShape{}, CopyThreadShape{}))); - using XE_Copy_D = decltype(make_tiled_copy(Copy_Atom{}, Layout{}, val_layout_store_D{})); - +//remember this PR https://github.com/intel/sycl-tla/pull/565/files private: - constexpr static bool is_source_supported = not cute::is_void_v && not cute::is_void_v; - constexpr static bool is_destination_supported = not cute::is_void_v && not cute::is_void_v; + constexpr static bool is_source_supported = not cute::is_void_v; + constexpr static bool is_destination_supported = not cute::is_void_v; constexpr static bool is_m_major_C = detail::is_m_major(); constexpr static bool is_m_major_D = detail::is_m_major(); @@ -166,8 +154,11 @@ class CollectiveEpilogue< // Device side epilogue params struct Params { typename FusionCallbacks::Params thread{}; - XE_Copy_C xe_load_c; - XE_Copy_D xe_store_d; + ElementC const* ptr_C; + ElementD* ptr_D; + int M, N, K, L; + StrideC dC; + StrideD dD; }; // @@ -184,22 +175,13 @@ class CollectiveEpilogue< auto problem_shape_MNKL = append<4>(problem_shape, 1); auto [M, N, K, L] = problem_shape_MNKL; - XE_Copy_C xe_load_c = {}; - if constexpr (is_source_supported) { - auto mC = make_tensor(make_gmem_ptr(args.ptr_C), make_layout(make_shape(M, N, L), args.dC)); - xe_load_c = {xe_load_c.with(mC)}; - } - - XE_Copy_D xe_store_d = {}; - if constexpr (is_destination_supported) { - auto mD = make_tensor(make_gmem_ptr(args.ptr_D), make_layout(make_shape(M, N, L), args.dD)); - xe_store_d = {xe_store_d.with(mD)}; - } - return { FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, workspace), - xe_load_c, - xe_store_d, + args.ptr_C, + args.ptr_D, + M, N, K, L, + args.dC, + args.dD }; } @@ -286,7 +268,6 @@ class CollectiveEpilogue< TiledMma tiled_mma, int thread_idx) { - (void) tiled_mma; using namespace cute; static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]"); @@ -297,143 +278,149 @@ class CollectiveEpilogue< static constexpr auto BLK_M = get<0>(CtaTileMNK{}); static constexpr auto BLK_N = get<1>(CtaTileMNK{}); static constexpr auto BLK_K = get<2>(CtaTileMNK{}); - // static_assert(is_same_v, "assertation fail"); - static constexpr auto ATOM_M = get<1>(typename TiledMma::ThrLayoutVMNK{}.shape()); - static constexpr auto ATOM_N = get<2>(typename TiledMma::ThrLayoutVMNK{}.shape()); - static constexpr auto ATOM_K = get<3>(typename TiledMma::ThrLayoutVMNK{}.shape()); - - static_assert( - BLK_M % ATOM_M == 0 && - BLK_N % ATOM_N == 0 && - BLK_K % ATOM_K == 0, - "expected CTATileMNK to be evenly divided by TiledMma::ThrLayoutVMNK"); - static constexpr auto SG_M = BLK_M / ATOM_M; - static constexpr auto 
SG_N = BLK_N / ATOM_N; - static constexpr auto SG_K = BLK_K / ATOM_K; - using SubgroupTileShape = Shape; - - static constexpr int FragsM = get<0>(SubgroupTileShape{}) / get<0>(MmaAtomShape()); // A frags per sub_group - static constexpr int FragsN = get<1>(SubgroupTileShape{}) / get<1>(MmaAtomShape()); // B frags per sub_group - - static constexpr int FragmentSize = (get<0>(MmaAtomShape()) * get<1>(MmaAtomShape())) / SubgroupSize; - + // Indexing variables auto [M, N, K, L] = problem_shape_mnkl; auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; - auto m_sg = get_sub_group_id() / ATOM_N; - auto n_sg = get_sub_group_id() % ATOM_N; - - auto mn_shape = shape(typename decltype(params.xe_store_d)::Tiler_MN{}); - - auto sg_local_m_coord = get_sub_group_id() / ATOM_N; - auto sg_local_n_coord = get_sub_group_id() % ATOM_N; - - auto sg_m_coord = m_coord * ATOM_M + sg_local_m_coord; - auto sg_n_coord = n_coord * ATOM_N + sg_local_n_coord; - auto sg_coord = make_coord(sg_m_coord, sg_n_coord, k_coord, l_coord); + + // Workgroup coordinates (no subgroup indexing needed) + auto wg_coord = make_coord(m_coord, n_coord, k_coord, l_coord); + auto batch_idx = get<3>(wg_coord); bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); + auto mC = make_tensor(make_gmem_ptr(params.ptr_C), make_layout(make_shape(params.M, params.N, params.L), params.dC)); + auto mD = make_tensor(make_gmem_ptr(params.ptr_D), make_layout(make_shape(params.M, params.N, params.L), params.dD)); + + auto copy_c = [&]() { + if constexpr (!std::is_void_v) { + return make_block_2d_copy_A(CopyOpG2R{}, tiled_mma, mC(_,_,batch_idx)); + } else { + return make_block_2d_copy_A(tiled_mma, mC(_,_,batch_idx)); + } + }(); + auto copy_d = [&]() { + if constexpr (!std::is_void_v) { + return make_block_2d_copy_C(CopyOpR2G{}, tiled_mma, mD(_,_,batch_idx)); + } else { + return make_block_2d_copy_C(tiled_mma, mD(_,_,batch_idx)); + } + }(); + // Represent the full output tensor Tensor mD_mnl = cute::get_xe_tensor(make_shape(M,N,L)); - // Tile the output tensor per WG and select the tile for current WG - Tensor g_wg_D = local_tile(mD_mnl, take<0,2>(CtaTileMNK{}), make_coord(m_coord,n_coord,l_coord)); // (BLK_M,BLK_N) - - // Tile the output tensor per SG and select tile for the current SG - Tensor gD = local_tile(g_wg_D, take<0,2>(SubgroupTileShape{}), make_coord(m_sg,n_sg)); // (SG_M,SG_N) + // Tile the output tensor for the current workgroup + Tensor gD = local_tile(mD_mnl, take<0,2>(CtaTileMNK{}), remove<2>(wg_coord)); // (BLK_M,BLK_N) // change made - auto thread_xe_load_c = params.xe_load_c.get_thread_slice(thread_idx); + // Get thread-level partitioning across the entire workgroup tile + auto thread_xe_load_c = copy_c.get_thread_slice(thread_idx); Tensor tCgC = thread_xe_load_c.partition_S(gD); - auto thread_xe_store_d = params.xe_store_d.get_thread_slice(thread_idx); + auto thread_xe_store_d = copy_d.get_thread_slice(thread_idx); Tensor tCgD = thread_xe_store_d.partition_D(gD); - Tensor trC = make_tensor(Shape>{}); - Tensor trD_compute = make_tensor(Shape>{}); + // Create tensors sized for workgroup-level operation + Tensor trC = make_tensor(tCgC.shape()); + Tensor trD_compute = make_tensor(tCgD.shape()); + Tensor trD = make_tensor(tCgD.shape()); + + ThrCopy thread_g2r = copy_c.get_slice(thread_idx); - // Because Sm90 uses shared memory, they are not tied to using the same accumulator values - // for MMA and Epilogue. 
But because we are operating directly in the accumulators, we need to be - // sure that we are operating on the same values. - ThrCopy thread_g2r = params.xe_load_c.get_slice(thread_idx); + auto mn_shape = shape(typename decltype(copy_d)::Tiler_MN{}); // OOB predication for tile quantization "residue" - // Absolute coordinate tensors (dynamic) - Tensor mD_crd = make_identity_tensor(make_shape(M,N)); // (M,N) - Tensor cD = local_tile(mD_crd, take<0,2>(SubgroupTileShape{}), make_coord(sg_m_coord, sg_n_coord)); - Tensor cD_mn = local_tile(mD_crd, take<0,2>(CtaTileMNK{}), make_coord(m_coord, n_coord)); // (CTA_M,CTA_N) - Tensor tRS_cD_mn = thread_g2r.partition_S(flat_divide(cD_mn, mn_shape)); // (G2R,G2R_M,G2R_N,EPI_M,EPI_N) + Tensor mD_crd = make_identity_tensor(make_shape(M,N)); + Tensor cD = local_tile(mD_crd, take<0,2>(CtaTileMNK{}), make_coord(m_coord, n_coord)); + Tensor tRS_cD = thread_g2r.partition_S(flat_divide(cD, mn_shape)); - Tensor tRS_cD = make_coord_tensor(tRS_cD_mn.layout()); // (G2R,G2R_M,G2R_N,EPI_M,EPI_N) + Tensor tRS_cD_coord = make_coord_tensor(tRS_cD.layout()); - // Get the fusion callbacks - // Arguments passed here relate to sub-group tiles, rather than CTA (work-group) tiles + // Get fusion callbacks at workgroup level constexpr bool RefSrc = true; - auto residue_mn = make_coord(M, N); //TODO(Codeplay): this is not correct + auto residue_mn = make_coord(M, N); auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{ problem_shape_mnkl, - SubgroupTileShape{}, - sg_coord, + CtaTileMNK{}, // Use workgroup tile shape + wg_coord, // Use workgroup coordinates tiled_mma, mn_shape, - params.xe_store_d, + copy_d, cD, residue_mn, - tRS_cD, + tRS_cD_coord, residue_mn, trC, thread_idx, }; auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); + auto synchronize = [&] () {}; cst_callbacks.begin(); + // Load C tile if needed (distributed across all threads in workgroup) + if (is_C_load_needed) { + copy(copy_c, tCgC, trC); + } - auto acc_frag = recast>(accumulators); - auto trD_compute_frag = recast>(trD_compute); - - Tensor trD = make_tensor(Shape>{}); - auto trD_frag = recast>(trD); - - constexpr int ValuesLoaded = - FragsM * FragsN * FragmentSize * SubgroupSize * ATOM_M * ATOM_N * ATOM_K; - constexpr int MN = get<0>(CtaTileMNK{}) * get<1>(CtaTileMNK{}); - static_assert(ValuesLoaded == MN, "the total elements loaded by all threads should be the same as MxN" ); + // Single previsit for entire workgroup tile + cst_callbacks.previsit(0, 0, 0, is_C_load_needed); - auto synchronize = [&] () {}; + static constexpr int FragmentSize = get<0>(MmaAtomShape()) * get<1>(MmaAtomShape()); + constexpr int num_fragments = size(accumulators) / FragmentSize; + CUTLASS_PRAGMA_UNROLL - for (int epi_n = 0; epi_n < FragsN; epi_n++) { + for (int epi_v = 0; epi_v < num_fragments; ++epi_v) { + // Extract fragment + Array frg_acc; CUTLASS_PRAGMA_UNROLL - for (int epi_m = 0; epi_m < FragsM; epi_m++) { - cst_callbacks.begin_loop(epi_m, epi_n); - - if (is_C_load_needed) { - copy(params.xe_load_c, tCgC(_, epi_m, epi_n), trC); - } - - cst_callbacks.previsit(epi_m, epi_n, 0, is_C_load_needed); + for (int f = 0; f < FragmentSize; ++f) { + frg_acc[f] = accumulators(epi_v * FragmentSize + f); + } + + // Process fragment + auto result_frg = cst_callbacks.visit(frg_acc, epi_v, 0, 0); + + // Store results + CUTLASS_PRAGMA_UNROLL + for (int f = 0; f < FragmentSize; ++f) { + trD_compute(epi_v * FragmentSize + f) = result_frg[f]; + } + } - auto acc_frag_mn = acc_frag(_, epi_m, 
epi_n); + cst_callbacks.reduce(nullptr, synchronize, 0, 0, true, trD_compute); + if constexpr (is_destination_supported) { + // Convert fragments using NumericArrayConverter + constexpr int num_fragments_trD_compute = size(trD_compute) / FragmentSize; + using Converter = cutlass::NumericArrayConverter; + Converter converter{}; + + CUTLASS_PRAGMA_UNROLL + for (int epi_v = 0; epi_v < num_fragments_trD_compute; ++epi_v) { + // Extract compute fragment + Array trD_compute_frag; CUTLASS_PRAGMA_UNROLL - for (int epi_v = 0; epi_v < size<0>(trD_compute_frag); ++epi_v) { - trD_compute_frag(epi_v) = cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n); + for (int f = 0; f < FragmentSize; ++f) { + trD_compute_frag[f] = trD_compute(epi_v * FragmentSize + f); } - cst_callbacks.reduce(nullptr, synchronize, epi_m, epi_n, (epi_m == FragsM - 1 && epi_n == FragsN - 1), trD_compute_frag); - if constexpr (is_destination_supported) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(trD_compute_frag); ++i) { - trD_frag(i) = cutlass::NumericArrayConverter{}(trD_compute_frag(i)); - } - copy(params.xe_store_d, trD, tCgD(_, epi_m, epi_n)); - } + // Convert fragment + auto trD_frag = converter(trD_compute_frag); - cst_callbacks.end_loop(epi_m, epi_n); + // Store converted fragment + CUTLASS_PRAGMA_UNROLL + for (int f = 0; f < FragmentSize; ++f) { + trD(epi_v * FragmentSize + f) = trD_frag[f]; + + } } - } + + copy(copy_d, trD, tCgD); + } - cst_callbacks.end(); - } + cst_callbacks.end(); + +} private: Params const& params; @@ -447,4 +434,4 @@ class CollectiveEpilogue< } // namespace epilogue } // namespace cutlass -///////////////////////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////////////////////////////////////////////////////////////////// \ No newline at end of file diff --git a/include/cutlass/epilogue/collective/xe_epilogue_legacy.cpp b/include/cutlass/epilogue/collective/xe_epilogue_legacy.cpp new file mode 100644 index 0000000000..b2b029bff2 --- /dev/null +++ b/include/cutlass/epilogue/collective/xe_epilogue_legacy.cpp @@ -0,0 +1,450 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing elementwise operations used by epilogues. +*/ + +#pragma once + +#include +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_epilogue.hpp" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/epilogue/fusion/callbacks.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" +#include "cutlass/detail/layout.hpp" + +#include "cute/tensor.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class CtaTileMNK_, + class ElementC_, + class StrideC_, + class ElementD_, + class StrideD_, + class FusionCallbacks_, + class CopyOpG2R_, + class SmemLayoutAtomC_, + class CopyOpS2R_, + class CopyOpR2G_, + class SmemLayoutAtomD_, + class CopyOpR2S_ +> +class CollectiveEpilogue< + IntelXeL1Staged, + CtaTileMNK_, + ElementC_, + StrideC_, + ElementD_, + StrideD_, + FusionCallbacks_, + CopyOpG2R_, + SmemLayoutAtomC_, + CopyOpS2R_, + CopyOpR2G_, + SmemLayoutAtomD_, + CopyOpR2S_ +> { +public: + // + // Type Aliases + // + using DispatchPolicy = IntelXeL1Staged; + using CtaTileMNK = CtaTileMNK_; + using FusionCallbacks = FusionCallbacks_; + using ElementC = ElementC_; + using ElementAccumulator = ElementC_; + using StrideC = StrideC_; + using ElementD = ElementD_; + using StrideD = StrideD_; + using CopyOpG2R = CopyOpG2R_; + using SmemLayoutAtomC = SmemLayoutAtomC_; + using CopyOpS2R = CopyOpS2R_; + using CopyOpR2G = CopyOpR2G_; + using SmemLayoutAtomD = SmemLayoutAtomD_; + using CopyOpR2S = CopyOpR2S_; + + using ThreadEpilogueOp = typename fusion::FusionCallbacksTraits::Operation; + using GmemTiledCopyC = conditional_t, XE_2D_U32x8x16_LD_N, CopyOpG2R>; + using GmemTiledCopyD = cute::conditional_t && not cute::is_void_v, + CopyOpR2G, XE_2D_U32x8x16_ST_N>; + using ElementOutput = ElementD; + using ElementCompute = ElementAccumulator; + + static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; + + static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]"); + static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]"); + static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]"); + + static_assert(std::is_same_v, "Copy operation to shared memory is not supported"); + static_assert(std::is_same_v, "Copy operation to shared memory is not supported"); + static_assert(std::is_same_v, "Copy operation to shared memory is not supported"); + static_assert(std::is_same_v, "Copy operation to shared memory is not supported"); + + using CopyThreadShape = Shape<_1, Int>; + + using Trait_C = Copy_Traits; + using 
val_layout_load_C = decltype(make_layout(shape_div(typename Trait_C::BlockShape{}, CopyThreadShape{}))); + using XE_Copy_C = decltype(make_tiled_copy(Copy_Atom{}, Layout{}, val_layout_load_C{})); + + using Trait_D = Copy_Traits; + using val_layout_store_D = decltype(make_layout(shape_div(typename Trait_D::BlockShape{}, CopyThreadShape{}))); + using XE_Copy_D = decltype(make_tiled_copy(Copy_Atom{}, Layout{}, val_layout_store_D{})); + +private: + constexpr static bool is_source_supported = not cute::is_void_v && not cute::is_void_v; + constexpr static bool is_destination_supported = not cute::is_void_v && not cute::is_void_v; + + constexpr static bool is_m_major_C = detail::is_m_major(); + constexpr static bool is_m_major_D = detail::is_m_major(); + +public: + + using EmptyType = cute::tuple<>; + using SmemCStorage = EmptyType; + using SmemDStorage = EmptyType; + + struct TensorStorageImpl: cute::tuple { + using FusionStorage = typename FusionCallbacks::SharedStorage; + FusionStorage thread; + }; + + struct SharedStorage { + using TensorStorage = TensorStorageImpl; + + TensorStorage tensors; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + + // Host side epilogue arguments + struct Arguments { + typename FusionCallbacks::Arguments thread{}; + ElementC const* ptr_C; + StrideC dC; + ElementD* ptr_D; + StrideD dD; + }; + + // Device side epilogue params + struct Params { + typename FusionCallbacks::Params thread{}; + XE_Copy_C xe_load_c; + XE_Copy_D xe_store_d; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace) { + // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + XE_Copy_C xe_load_c = {}; + if constexpr (is_source_supported) { + auto mC = make_tensor(make_gmem_ptr(args.ptr_C), make_layout(make_shape(M, N, L), args.dC)); + xe_load_c = {xe_load_c.with(mC)}; + } + + XE_Copy_D xe_store_d = {}; + if constexpr (is_destination_supported) { + auto mD = make_tensor(make_gmem_ptr(args.ptr_D), make_layout(make_shape(M, N, L), args.dD)); + xe_store_d = {xe_store_d.with(mD)}; + } + + return { + FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, workspace), + xe_load_c, + xe_store_d, + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return Status::kSuccess; + } + + template + CUTLASS_HOST_DEVICE static bool + can_implement( + ProblemShape const& problem_shapes, + Arguments const& args) { + constexpr int copy_alignment_bits = 128; + constexpr int batch_alignment_bits = 512; + auto problem_shape_MNKL = append<4>(problem_shapes, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + bool implementable = true; + bool fusion_implementable = true; + + if constexpr (is_destination_supported) { + constexpr int min_aligned_elements_D = copy_alignment_bits / sizeof_bits::value; + implementable &= cutlass::detail::check_alignment(cute::make_shape(M,N,L), args.dD); + if (L > 1) { + constexpr int min_batch_aligned_elements_D = batch_alignment_bits / sizeof_bits::value; + implementable &= get<2>(args.dD) % 
min_batch_aligned_elements_D == 0; + } + } + + if constexpr (is_source_supported) { + constexpr int min_aligned_elements_C = copy_alignment_bits / sizeof_bits::value; + implementable &= cutlass::detail::check_alignment(cute::make_shape(M,N,L), args.dC); + if (L > 1) { + constexpr int min_batch_aligned_elements_C = batch_alignment_bits / sizeof_bits::value; + implementable &= get<2>(args.dC) % min_batch_aligned_elements_C == 0; + } + } + + fusion_implementable = fusion_implementable && FusionCallbacks::can_implement(problem_shape_MNKL, args.thread); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for XE 2D copy.\n"); + } + + if (!fusion_implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements for FusionCallbacks.\n"); + } + + return implementable && fusion_implementable; + } + + CUTLASS_HOST_DEVICE + CollectiveEpilogue(Params const& params_, TensorStorage const& shared_storage_) + : params(params_), fusion_callbacks(params_.thread, shared_storage_.thread) {} + + CUTLASS_DEVICE + bool + is_producer_load_needed() const { + return fusion_callbacks.is_producer_load_needed(); + } + + template< + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class Accumulator, + class TiledMma + > + CUTLASS_DEVICE void + operator() ( + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_MNK, + TileCoordMNKL tile_coord_mnkl, + Accumulator accumulators, + TiledMma tiled_mma, + int thread_idx) { + + (void) tiled_mma; + using namespace cute; + + static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]"); + static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]"); + static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]"); + + using MmaAtomShape = typename TiledMma::AtomShape_MNK; + static constexpr auto BLK_M = get<0>(CtaTileMNK{}); + static constexpr auto BLK_N = get<1>(CtaTileMNK{}); + static constexpr auto BLK_K = get<2>(CtaTileMNK{}); + // static_assert(is_same_v, "assertation fail"); + static constexpr auto ATOM_M = get<1>(typename TiledMma::ThrLayoutVMNK{}.shape()); + static constexpr auto ATOM_N = get<2>(typename TiledMma::ThrLayoutVMNK{}.shape()); + static constexpr auto ATOM_K = get<3>(typename TiledMma::ThrLayoutVMNK{}.shape()); + + static_assert( + BLK_M % ATOM_M == 0 && + BLK_N % ATOM_N == 0 && + BLK_K % ATOM_K == 0, + "expected CTATileMNK to be evenly divided by TiledMma::ThrLayoutVMNK"); + static constexpr auto SG_M = BLK_M / ATOM_M; + static constexpr auto SG_N = BLK_N / ATOM_N; + static constexpr auto SG_K = BLK_K / ATOM_K; + using SubgroupTileShape = Shape; + + static constexpr int FragsM = get<0>(SubgroupTileShape{}) / get<0>(MmaAtomShape()); // A frags per sub_group + static constexpr int FragsN = get<1>(SubgroupTileShape{}) / get<1>(MmaAtomShape()); // B frags per sub_group + + static constexpr int FragmentSize = (get<0>(MmaAtomShape()) * get<1>(MmaAtomShape())) / SubgroupSize; + + // Indexing variables + auto [M, N, K, L] = problem_shape_mnkl; + auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; + auto m_sg = get_sub_group_id() / ATOM_N; + auto n_sg = get_sub_group_id() % ATOM_N; + + auto mn_shape = shape(typename decltype(params.xe_store_d)::Tiler_MN{}); + + auto sg_local_m_coord = get_sub_group_id() / ATOM_N; + auto sg_local_n_coord = get_sub_group_id() % ATOM_N; + + auto sg_m_coord = m_coord * ATOM_M + sg_local_m_coord; + auto sg_n_coord = 
n_coord * ATOM_N + sg_local_n_coord; + auto sg_coord = make_coord(sg_m_coord, sg_n_coord, k_coord, l_coord); + + bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); + + // Represent the full output tensor + Tensor mD_mnl = cute::get_xe_tensor(make_shape(M,N,L)); + + // Tile the output tensor per WG and select the tile for current WG + Tensor g_wg_D = local_tile(mD_mnl, take<0,2>(CtaTileMNK{}), make_coord(m_coord,n_coord,l_coord)); // (BLK_M,BLK_N) + + // Tile the output tensor per SG and select tile for the current SG + Tensor gD = local_tile(g_wg_D, take<0,2>(SubgroupTileShape{}), make_coord(m_sg,n_sg)); // (SG_M,SG_N) + + auto thread_xe_load_c = params.xe_load_c.get_thread_slice(thread_idx); + Tensor tCgC = thread_xe_load_c.partition_S(gD); + + auto thread_xe_store_d = params.xe_store_d.get_thread_slice(thread_idx); + Tensor tCgD = thread_xe_store_d.partition_D(gD); + + Tensor trC = make_tensor(Shape>{}); + Tensor trD_compute = make_tensor(Shape>{}); + + // Because Sm90 uses shared memory, they are not tied to using the same accumulator values + // for MMA and Epilogue. But because we are operating directly in the accumulators, we need to be + // sure that we are operating on the same values. + ThrCopy thread_g2r = params.xe_load_c.get_slice(thread_idx); + + // OOB predication for tile quantization "residue" + // Absolute coordinate tensors (dynamic) + Tensor mD_crd = make_identity_tensor(make_shape(M,N)); // (M,N) + Tensor cD = local_tile(mD_crd, take<0,2>(SubgroupTileShape{}), make_coord(sg_m_coord, sg_n_coord)); + Tensor cD_mn = local_tile(mD_crd, take<0,2>(CtaTileMNK{}), make_coord(m_coord, n_coord)); // (CTA_M,CTA_N) + Tensor tRS_cD_mn = thread_g2r.partition_S(flat_divide(cD_mn, mn_shape)); // (G2R,G2R_M,G2R_N,EPI_M,EPI_N) + + Tensor tRS_cD = make_coord_tensor(tRS_cD_mn.layout()); // (G2R,G2R_M,G2R_N,EPI_M,EPI_N) + + // Get the fusion callbacks + // Arguments passed here relate to sub-group tiles, rather than CTA (work-group) tiles + constexpr bool RefSrc = true; + auto residue_mn = make_coord(M, N); //TODO(Codeplay): this is not correct + auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{ + problem_shape_mnkl, + SubgroupTileShape{}, + sg_coord, + tiled_mma, + mn_shape, + params.xe_store_d, + cD, + residue_mn, + tRS_cD, + residue_mn, + trC, + thread_idx, + }; + auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); + + cst_callbacks.begin(); + + auto acc_frag = recast>(accumulators); + auto trD_compute_frag = recast>(trD_compute); + + Tensor trD = make_tensor(Shape>{}); + auto trD_frag = recast>(trD); + + constexpr int ValuesLoaded = + FragsM * FragsN * FragmentSize * SubgroupSize * ATOM_M * ATOM_N * ATOM_K; + constexpr int MN = get<0>(CtaTileMNK{}) * get<1>(CtaTileMNK{}); + static_assert(ValuesLoaded == MN, "the total elements loaded by all threads should be the same as MxN" ); + + auto synchronize = [&] () {}; + CUTLASS_PRAGMA_UNROLL + for (int epi_n = 0; epi_n < FragsN; epi_n++) { + CUTLASS_PRAGMA_UNROLL + for (int epi_m = 0; epi_m < FragsM; epi_m++) { + cst_callbacks.begin_loop(epi_m, epi_n); + + if (is_C_load_needed) { + copy(params.xe_load_c, tCgC(_, epi_m, epi_n), trC); + } + + cst_callbacks.previsit(epi_m, epi_n, 0, is_C_load_needed); + + auto acc_frag_mn = acc_frag(_, epi_m, epi_n); + + CUTLASS_PRAGMA_UNROLL + for (int epi_v = 0; epi_v < size<0>(trD_compute_frag); ++epi_v) { + trD_compute_frag(epi_v) = cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n); + } + 
cst_callbacks.reduce(nullptr, synchronize, epi_m, epi_n, (epi_m == FragsM - 1 && epi_n == FragsN - 1), trD_compute_frag); + + if constexpr (is_destination_supported) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(trD_compute_frag); ++i) { + trD_frag(i) = cutlass::NumericArrayConverter{}(trD_compute_frag(i)); + } + copy(params.xe_store_d, trD, tCgD(_, epi_m, epi_n)); + } + + cst_callbacks.end_loop(epi_m, epi_n); + } + } + + cst_callbacks.end(); + } + +private: + Params const& params; + FusionCallbacks fusion_callbacks; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/dispatch_policy.hpp b/include/cutlass/epilogue/dispatch_policy.hpp index fbb4a40b52..b661d51c10 100644 --- a/include/cutlass/epilogue/dispatch_policy.hpp +++ b/include/cutlass/epilogue/dispatch_policy.hpp @@ -304,6 +304,12 @@ struct IntelXeXMX16 { struct IntelXeXMX16Group { static constexpr int SubgroupSize = 16; }; + +// Note: This dispatch policy is specifically added for CollectiveEpilogue to support +// the integration of new MMA atoms (XE_DPAS_TT) and copy atoms for Intel XE architecture +struct IntelXeL1Staged { + static constexpr int SubgroupSize = 16; +}; #endif ////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/fusion/xe_callbacks.hpp b/include/cutlass/epilogue/fusion/xe_callbacks.hpp index 5173d77000..71220d281d 100644 --- a/include/cutlass/epilogue/fusion/xe_callbacks.hpp +++ b/include/cutlass/epilogue/fusion/xe_callbacks.hpp @@ -109,6 +109,59 @@ struct FusionCallbacks< using Impl::Impl; }; +template < + class ElementOutput_, + class ElementCompute_, + class ElementSource_, + class ElementScalar_, + FloatRoundStyle RoundStyle_, + class CtaTileShapeMNK_, + class EpilogueTile_ +> +struct FusionCallbacks< + epilogue::IntelXeL1Staged, + fusion::LinearCombination, + CtaTileShapeMNK_, + EpilogueTile_ +> : Sm90LinearCombination::type, ElementCompute_, ElementSource_, ElementScalar_, RoundStyle_> { + + using Impl = Sm90LinearCombination::type, ElementCompute_, ElementSource_, ElementScalar_, RoundStyle_>; + using ElementOutput = ElementOutput_; + using ElementCompute = ElementCompute_; + using ElementSource = ElementSource_; + using ElementScalar = ElementScalar_; + using Operation = fusion::LinearCombination; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + operator typename Impl::Arguments() const { + return + { // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + template < template class ActivationFn_, From aaf268514226b49b8c25de592b606de469afc189 Mon Sep 17 00:00:00 2001 From: Anamika Chatterjee Date: Wed, 22 Oct 2025 
12:01:26 +0300 Subject: [PATCH 2/7] Use new make_block_2d_copy_{C,D} APIs for loads/stores --- include/cutlass/epilogue/collective/xe_epilogue.hpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/include/cutlass/epilogue/collective/xe_epilogue.hpp b/include/cutlass/epilogue/collective/xe_epilogue.hpp index 6c5720528d..af56ace945 100644 --- a/include/cutlass/epilogue/collective/xe_epilogue.hpp +++ b/include/cutlass/epilogue/collective/xe_epilogue.hpp @@ -293,16 +293,16 @@ class CollectiveEpilogue< auto copy_c = [&]() { if constexpr (!std::is_void_v) { - return make_block_2d_copy_A(CopyOpG2R{}, tiled_mma, mC(_,_,batch_idx)); + return make_block_2d_copy_CD(CopyOpG2R{}, tiled_mma, mC(_,_,batch_idx)); } else { - return make_block_2d_copy_A(tiled_mma, mC(_,_,batch_idx)); + return make_block_2d_copy_C(tiled_mma, mC(_,_,batch_idx)); } }(); auto copy_d = [&]() { if constexpr (!std::is_void_v) { - return make_block_2d_copy_C(CopyOpR2G{}, tiled_mma, mD(_,_,batch_idx)); + return make_block_2d_copy_CD(CopyOpR2G{}, tiled_mma, mD(_,_,batch_idx)); } else { - return make_block_2d_copy_C(tiled_mma, mD(_,_,batch_idx)); + return make_block_2d_copy_D(tiled_mma, mD(_,_,batch_idx)); } }(); From 407a875781bbb076eed133a2da9331d878ac1d8c Mon Sep 17 00:00:00 2001 From: Anamika Chatterjee Date: Wed, 22 Oct 2025 16:56:46 +0300 Subject: [PATCH 3/7] Code Cleanup --- include/cute/atom/copy_traits_xe_2d.hpp | 42 +++++++++++++++++ .../epilogue/collective/xe_epilogue.hpp | 46 ++++++++----------- 2 files changed, 60 insertions(+), 28 deletions(-) diff --git a/include/cute/atom/copy_traits_xe_2d.hpp b/include/cute/atom/copy_traits_xe_2d.hpp index 7c526a405e..3a579a4ec4 100644 --- a/include/cute/atom/copy_traits_xe_2d.hpp +++ b/include/cute/atom/copy_traits_xe_2d.hpp @@ -1079,6 +1079,48 @@ make_block_2d_prefetch(PrefetchOp const& op, return make_block_2d_copy(op, stride, x_mode, y_mode, atom_shape, sv_layout); } +// +// Block 2D Copy Utilities - Helper functions for conditional copy operation selection +// +template +auto get_block_2d_copy_A(TiledMMA const& tiled_mma, ATensor const& a_tensor) +{ + if constexpr (!std::is_void_v) { + return make_block_2d_copy_A(CopyOp{}, tiled_mma, a_tensor); + } else { + return make_block_2d_copy_A(tiled_mma, a_tensor); + } +} + +template +auto get_block_2d_copy_B(TiledMMA const& tiled_mma, BTensor const& b_tensor) +{ + if constexpr (!std::is_void_v) { + return make_block_2d_copy_B(CopyOp{}, tiled_mma, b_tensor); + } else { + return make_block_2d_copy_B(tiled_mma, b_tensor); + } +} + +template +auto get_block_2d_copy_C(TiledMMA const& tiled_mma, CTensor const& c_tensor) +{ + if constexpr (!std::is_void_v) { + return make_block_2d_copy_CD(CopyOp{}, tiled_mma, c_tensor); + } else { + return make_block_2d_copy_C(tiled_mma, c_tensor); + } +} + +template +auto get_block_2d_copy_D(TiledMMA const& tiled_mma, DTensor const& d_tensor) +{ + if constexpr (!std::is_void_v) { + return make_block_2d_copy_CD(CopyOp{}, tiled_mma, d_tensor); + } else { + return make_block_2d_copy_D(tiled_mma, d_tensor); + } +} // diff --git a/include/cutlass/epilogue/collective/xe_epilogue.hpp b/include/cutlass/epilogue/collective/xe_epilogue.hpp index af56ace945..231ca44f98 100644 --- a/include/cutlass/epilogue/collective/xe_epilogue.hpp +++ b/include/cutlass/epilogue/collective/xe_epilogue.hpp @@ -142,6 +142,15 @@ class CollectiveEpilogue< }; using TensorStorage = typename SharedStorage::TensorStorage; + // Helper to get tensor types + template + using TensorTypeC = 
decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), + make_layout(make_shape(int{}, int{}, int{}), Stride{}))); + + template + using TensorTypeD = decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), + make_layout(make_shape(int{}, int{}, int{}), Stride{}))); + // Host side epilogue arguments struct Arguments { typename FusionCallbacks::Arguments thread{}; @@ -154,11 +163,8 @@ class CollectiveEpilogue< // Device side epilogue params struct Params { typename FusionCallbacks::Params thread{}; - ElementC const* ptr_C; - ElementD* ptr_D; - int M, N, K, L; - StrideC dC; - StrideD dD; + TensorTypeC mC; + TensorTypeD mD; }; // @@ -174,14 +180,13 @@ class CollectiveEpilogue< // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) auto problem_shape_MNKL = append<4>(problem_shape, 1); auto [M, N, K, L] = problem_shape_MNKL; + auto mC = make_tensor(make_gmem_ptr(args.ptr_C), make_layout(make_shape(M, N, L), args.dC)); + auto mD = make_tensor(make_gmem_ptr(args.ptr_D), make_layout(make_shape(M, N, L), args.dD)); return { FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, workspace), - args.ptr_C, - args.ptr_D, - M, N, K, L, - args.dC, - args.dD + mC, + mD }; } @@ -285,27 +290,12 @@ class CollectiveEpilogue< // Workgroup coordinates (no subgroup indexing needed) auto wg_coord = make_coord(m_coord, n_coord, k_coord, l_coord); - auto batch_idx = get<3>(wg_coord); bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); - auto mC = make_tensor(make_gmem_ptr(params.ptr_C), make_layout(make_shape(params.M, params.N, params.L), params.dC)); - auto mD = make_tensor(make_gmem_ptr(params.ptr_D), make_layout(make_shape(params.M, params.N, params.L), params.dD)); - - auto copy_c = [&]() { - if constexpr (!std::is_void_v) { - return make_block_2d_copy_CD(CopyOpG2R{}, tiled_mma, mC(_,_,batch_idx)); - } else { - return make_block_2d_copy_C(tiled_mma, mC(_,_,batch_idx)); - } - }(); - auto copy_d = [&]() { - if constexpr (!std::is_void_v) { - return make_block_2d_copy_CD(CopyOpR2G{}, tiled_mma, mD(_,_,batch_idx)); - } else { - return make_block_2d_copy_D(tiled_mma, mD(_,_,batch_idx)); - } - }(); + auto batch_idx = get<3>(wg_coord); + auto copy_c = get_block_2d_copy_C(tiled_mma, params.mC(_,_,batch_idx)); + auto copy_d = get_block_2d_copy_D(tiled_mma, params.mD(_,_,batch_idx)); // Represent the full output tensor Tensor mD_mnl = cute::get_xe_tensor(make_shape(M,N,L)); From 12282e8f782db7774aef507e7cba319c9b56452f Mon Sep 17 00:00:00 2001 From: Anamika Chatterjee Date: Thu, 23 Oct 2025 11:39:58 +0530 Subject: [PATCH 4/7] Rename xe_epilogue_legacy.cpp to xe_epilogue_legacy.hpp Added right extension --- .../collective/{xe_epilogue_legacy.cpp => xe_epilogue_legacy.hpp} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename include/cutlass/epilogue/collective/{xe_epilogue_legacy.cpp => xe_epilogue_legacy.hpp} (100%) diff --git a/include/cutlass/epilogue/collective/xe_epilogue_legacy.cpp b/include/cutlass/epilogue/collective/xe_epilogue_legacy.hpp similarity index 100% rename from include/cutlass/epilogue/collective/xe_epilogue_legacy.cpp rename to include/cutlass/epilogue/collective/xe_epilogue_legacy.hpp From 6596ac82a419b40f51684f48e89f03c43c5da2f0 Mon Sep 17 00:00:00 2001 From: Anamika Chatterjee Date: Thu, 23 Oct 2025 11:54:43 +0530 Subject: [PATCH 5/7] Update xe_epilogue_legacy.hpp Added legacy dispatchpolicy --- include/cutlass/epilogue/collective/xe_epilogue_legacy.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) 
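Note: after this patch the legacy specialization keeps the original IntelXeXMX16 dispatch policy, while xe_epilogue.hpp owns the new IntelXeL1Staged path that uses the new 2D block copy atoms. A minimal sketch of how the two policies now coexist (a sketch only; the aliases below are illustrative and not part of the patch):

    // Legacy path: original subgroup-tiled epilogue, now living in xe_epilogue_legacy.hpp.
    using LegacyEpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16;
    // New path: reworked epilogue in xe_epilogue.hpp, added in PATCH 1/7.
    using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeL1Staged;
    // Both policies are consumed by the same CollectiveEpilogue front-end; only the
    // partial specialization selected by the first template argument differs.
    static_assert(LegacyEpilogueDispatchPolicy::SubgroupSize ==
                  EpilogueDispatchPolicy::SubgroupSize,
                  "both Xe epilogue policies use 16-wide subgroups");
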
diff --git a/include/cutlass/epilogue/collective/xe_epilogue_legacy.hpp b/include/cutlass/epilogue/collective/xe_epilogue_legacy.hpp index b2b029bff2..05437ab180 100644 --- a/include/cutlass/epilogue/collective/xe_epilogue_legacy.hpp +++ b/include/cutlass/epilogue/collective/xe_epilogue_legacy.hpp @@ -68,7 +68,7 @@ template < class CopyOpR2S_ > class CollectiveEpilogue< - IntelXeL1Staged, + IntelXeXMX16, CtaTileMNK_, ElementC_, StrideC_, @@ -86,7 +86,7 @@ class CollectiveEpilogue< // // Type Aliases // - using DispatchPolicy = IntelXeL1Staged; + using DispatchPolicy = IntelXeXMX16; using CtaTileMNK = CtaTileMNK_; using FusionCallbacks = FusionCallbacks_; using ElementC = ElementC_; From 7ea69d52291dba585401d1c086df3edc6b4d0531 Mon Sep 17 00:00:00 2001 From: Anamika Chatterjee Date: Thu, 23 Oct 2025 18:13:37 +0300 Subject: [PATCH 6/7] Avoid register spills --- examples/00_bmg_gemm/00_bmg_gemm.cpp | 4 +- .../epilogue/collective/xe_epilogue.hpp | 170 +++++++++++------- 2 files changed, 103 insertions(+), 71 deletions(-) diff --git a/examples/00_bmg_gemm/00_bmg_gemm.cpp b/examples/00_bmg_gemm/00_bmg_gemm.cpp index 040039b88b..d7f5178785 100644 --- a/examples/00_bmg_gemm/00_bmg_gemm.cpp +++ b/examples/00_bmg_gemm/00_bmg_gemm.cpp @@ -350,8 +350,8 @@ int main(int argc, const char** argv) // Refer https://github.com/intel/sycl-tla/blob/main/media/docs/cpp/xe_rearchitecture.md using GmemTiledCopyA = void; //XE_LOAD_2D<16, 32, 32>; using GmemTiledCopyB = void; //XE_LOAD_2D_VNNI<16, 32, 32>; - using GmemTiledCopyC = void; //XE_LOAD_2D<32, 8, 16>; - using GmemTiledCopyD = void; //XE_STORE_2D<32, 8, 16>; + using GmemTiledCopyC = XE_LOAD_2D<32, 8, 16>; + using GmemTiledCopyD = XE_STORE_2D<32, 8, 16>; diff --git a/include/cutlass/epilogue/collective/xe_epilogue.hpp b/include/cutlass/epilogue/collective/xe_epilogue.hpp index 231ca44f98..c910a602a4 100644 --- a/include/cutlass/epilogue/collective/xe_epilogue.hpp +++ b/include/cutlass/epilogue/collective/xe_epilogue.hpp @@ -257,6 +257,24 @@ class CollectiveEpilogue< return fusion_callbacks.is_producer_load_needed(); } + template + CUTLASS_DEVICE auto reshape_into_smaller_fragments(Tensor&& tensor) { + using namespace cute; + + auto target_stride = make_stride( + make_stride(cute::ScaledBasis, 0>{}, _0{}), + cute::ScaledBasis, 0>{}, + cute::ScaledBasis, 1>{} + ); + + auto target_layout = make_layout( + make_shape(make_shape(_8{}, _1{}), _4{}, _4{}), + target_stride + ); + + return make_tensor(tensor.data(), target_layout); +} + template< class ProblemShapeMNKL, class TileShapeMNK, @@ -283,25 +301,52 @@ class CollectiveEpilogue< static constexpr auto BLK_M = get<0>(CtaTileMNK{}); static constexpr auto BLK_N = get<1>(CtaTileMNK{}); static constexpr auto BLK_K = get<2>(CtaTileMNK{}); + static constexpr auto ATOM_M = get<1>(typename TiledMma::ThrLayoutVMNK{}.shape()); + static constexpr auto ATOM_N = get<2>(typename TiledMma::ThrLayoutVMNK{}.shape()); + static constexpr auto ATOM_K = get<3>(typename TiledMma::ThrLayoutVMNK{}.shape()); + + static_assert( + BLK_M % ATOM_M == 0 && + BLK_N % ATOM_N == 0 && + BLK_K % ATOM_K == 0, + "expected CTATileMNK to be evenly divided by TiledMma::ThrLayoutVMNK"); + static constexpr auto SG_M = BLK_M / ATOM_M; + static constexpr auto SG_N = BLK_N / ATOM_N; + static constexpr auto SG_K = BLK_K / ATOM_K; + using SubgroupTileShape = Shape; + + static constexpr int FragsM = get<0>(SubgroupTileShape{}) / get<0>(MmaAtomShape()); // A frags per sub_group + static constexpr int FragsN = get<1>(SubgroupTileShape{}) / 
get<1>(MmaAtomShape()); // B frags per sub_group + + static constexpr int FragmentSize = (get<0>(MmaAtomShape()) * get<1>(MmaAtomShape())) / SubgroupSize; // Indexing variables auto [M, N, K, L] = problem_shape_mnkl; auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; + auto m_sg = get_sub_group_id() / ATOM_N; + auto n_sg = get_sub_group_id() % ATOM_N; + + auto sg_local_m_coord = get_sub_group_id() / ATOM_N; + auto sg_local_n_coord = get_sub_group_id() % ATOM_N; + + auto sg_m_coord = m_coord * ATOM_M + sg_local_m_coord; + auto sg_n_coord = n_coord * ATOM_N + sg_local_n_coord; + auto sg_coord = make_coord(sg_m_coord, sg_n_coord, k_coord, l_coord); - // Workgroup coordinates (no subgroup indexing needed) auto wg_coord = make_coord(m_coord, n_coord, k_coord, l_coord); - bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); auto batch_idx = get<3>(wg_coord); auto copy_c = get_block_2d_copy_C(tiled_mma, params.mC(_,_,batch_idx)); auto copy_d = get_block_2d_copy_D(tiled_mma, params.mD(_,_,batch_idx)); + + // Represent the full output tensor Tensor mD_mnl = cute::get_xe_tensor(make_shape(M,N,L)); // Tile the output tensor for the current workgroup - Tensor gD = local_tile(mD_mnl, take<0,2>(CtaTileMNK{}), remove<2>(wg_coord)); // (BLK_M,BLK_N) // change made + Tensor gD = local_tile(mD_mnl, take<0,2>(CtaTileMNK{}), remove<2>(wg_coord)); // (BLK_M,BLK_N) // Get thread-level partitioning across the entire workgroup tile auto thread_xe_load_c = copy_c.get_thread_slice(thread_idx); @@ -310,105 +355,92 @@ class CollectiveEpilogue< auto thread_xe_store_d = copy_d.get_thread_slice(thread_idx); Tensor tCgD = thread_xe_store_d.partition_D(gD); - // Create tensors sized for workgroup-level operation - Tensor trC = make_tensor(tCgC.shape()); - Tensor trD_compute = make_tensor(tCgD.shape()); - Tensor trD = make_tensor(tCgD.shape()); + auto tCgC_frag = reshape_into_smaller_fragments(tCgC); + auto tCgD_frag = reshape_into_smaller_fragments(tCgD); + Tensor trC = make_tensor(Shape>{}); + Tensor trD_compute = make_tensor(Shape>{}); ThrCopy thread_g2r = copy_c.get_slice(thread_idx); - auto mn_shape = shape(typename decltype(copy_d)::Tiler_MN{}); // OOB predication for tile quantization "residue" - Tensor mD_crd = make_identity_tensor(make_shape(M,N)); - Tensor cD = local_tile(mD_crd, take<0,2>(CtaTileMNK{}), make_coord(m_coord, n_coord)); - Tensor tRS_cD = thread_g2r.partition_S(flat_divide(cD, mn_shape)); + // Absolute coordinate tensors (dynamic) + Tensor mD_crd = make_identity_tensor(make_shape(M,N)); // (M,N) + Tensor cD = local_tile(mD_crd, take<0,2>(SubgroupTileShape{}), make_coord(sg_m_coord, sg_n_coord)); + Tensor cD_mn = local_tile(mD_crd, take<0,2>(CtaTileMNK{}), make_coord(m_coord, n_coord)); // (CTA_M,CTA_N) + Tensor tRS_cD_mn = thread_g2r.partition_S(flat_divide(cD_mn, mn_shape)); // (G2R,G2R_M,G2R_N,EPI_M,EPI_N) - Tensor tRS_cD_coord = make_coord_tensor(tRS_cD.layout()); + Tensor tRS_cD = make_coord_tensor(tRS_cD_mn.layout()); - // Get fusion callbacks at workgroup level + // Get the fusion callbacks + // Arguments passed here relate to sub-group tiles, rather than CTA (work-group) tiles constexpr bool RefSrc = true; - auto residue_mn = make_coord(M, N); + auto residue_mn = make_coord(M, N); //TODO(Codeplay): this is not correct auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{ problem_shape_mnkl, - CtaTileMNK{}, // Use workgroup tile shape - wg_coord, // Use workgroup coordinates + SubgroupTileShape{}, + sg_coord, tiled_mma, mn_shape, copy_d, 
cD, residue_mn, - tRS_cD_coord, + tRS_cD, residue_mn, trC, thread_idx, }; auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); - auto synchronize = [&] () {}; cst_callbacks.begin(); - // Load C tile if needed (distributed across all threads in workgroup) - if (is_C_load_needed) { - copy(copy_c, tCgC, trC); - } - // Single previsit for entire workgroup tile - cst_callbacks.previsit(0, 0, 0, is_C_load_needed); - - static constexpr int FragmentSize = get<0>(MmaAtomShape()) * get<1>(MmaAtomShape()); - constexpr int num_fragments = size(accumulators) / FragmentSize; + auto acc_frag = recast>(accumulators); + auto trD_compute_frag = recast>(trD_compute); + + Tensor trD = make_tensor(Shape>{}); + auto trD_frag = recast>(trD); + + constexpr int ValuesLoaded = + FragsM * FragsN * FragmentSize * SubgroupSize * ATOM_M * ATOM_N * ATOM_K; + constexpr int MN = get<0>(CtaTileMNK{}) * get<1>(CtaTileMNK{}); + static_assert(ValuesLoaded == MN, "the total elements loaded by all threads should be the same as MxN" ); + + auto synchronize = [&] () {}; CUTLASS_PRAGMA_UNROLL - for (int epi_v = 0; epi_v < num_fragments; ++epi_v) { - // Extract fragment - Array frg_acc; + for (int epi_n = 0; epi_n < FragsN; epi_n++) { CUTLASS_PRAGMA_UNROLL - for (int f = 0; f < FragmentSize; ++f) { - frg_acc[f] = accumulators(epi_v * FragmentSize + f); - } - - // Process fragment - auto result_frg = cst_callbacks.visit(frg_acc, epi_v, 0, 0); - - // Store results - CUTLASS_PRAGMA_UNROLL - for (int f = 0; f < FragmentSize; ++f) { - trD_compute(epi_v * FragmentSize + f) = result_frg[f]; - } - } + for (int epi_m = 0; epi_m < FragsM; epi_m++) { + cst_callbacks.begin_loop(epi_m, epi_n); - cst_callbacks.reduce(nullptr, synchronize, 0, 0, true, trD_compute); + if (is_C_load_needed) { + copy(copy_c, tCgC_frag(_, epi_m, epi_n), trC); + } + + cst_callbacks.previsit(epi_m, epi_n, 0, is_C_load_needed); + + auto acc_frag_mn = acc_frag(_, epi_m, epi_n); - if constexpr (is_destination_supported) { - // Convert fragments using NumericArrayConverter - constexpr int num_fragments_trD_compute = size(trD_compute) / FragmentSize; - using Converter = cutlass::NumericArrayConverter; - Converter converter{}; - - CUTLASS_PRAGMA_UNROLL - for (int epi_v = 0; epi_v < num_fragments_trD_compute; ++epi_v) { - // Extract compute fragment - Array trD_compute_frag; CUTLASS_PRAGMA_UNROLL - for (int f = 0; f < FragmentSize; ++f) { - trD_compute_frag[f] = trD_compute(epi_v * FragmentSize + f); + for (int epi_v = 0; epi_v < size<0>(trD_compute_frag); ++epi_v) { + trD_compute_frag(epi_v) = cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n); } - - // Convert fragment - auto trD_frag = converter(trD_compute_frag); - - // Store converted fragment - CUTLASS_PRAGMA_UNROLL - for (int f = 0; f < FragmentSize; ++f) { - trD(epi_v * FragmentSize + f) = trD_frag[f]; - + cst_callbacks.reduce(nullptr, synchronize, epi_m, epi_n, (epi_m == FragsM - 1 && epi_n == FragsN - 1), trD_compute_frag); + + if constexpr (is_destination_supported) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(trD_compute_frag); ++i) { + trD_frag(i) = cutlass::NumericArrayConverter{}(trD_compute_frag(i)); + } + copy(copy_d, trD, tCgD_frag(_, epi_m, epi_n)); } + + cst_callbacks.end_loop(epi_m, epi_n); + } - - copy(copy_d, trD, tCgD); - } + } - cst_callbacks.end(); + cst_callbacks.end(); } From 6944d9075c7e7c81723009375e9324e73d806414 Mon Sep 17 00:00:00 2001 From: Anamika Chatterjee Date: Fri, 24 Oct 2025 12:03:37 +0300 Subject: [PATCH 7/7] remove hardcode layout 
--- .../epilogue/collective/xe_epilogue.hpp | 42 ++++++++++++------- 1 file changed, 26 insertions(+), 16 deletions(-) diff --git a/include/cutlass/epilogue/collective/xe_epilogue.hpp b/include/cutlass/epilogue/collective/xe_epilogue.hpp index c910a602a4..0a5abe7b78 100644 --- a/include/cutlass/epilogue/collective/xe_epilogue.hpp +++ b/include/cutlass/epilogue/collective/xe_epilogue.hpp @@ -258,21 +258,34 @@ class CollectiveEpilogue< } template - CUTLASS_DEVICE auto reshape_into_smaller_fragments(Tensor&& tensor) { + CUTLASS_DEVICE auto reshape_with_unit_insertion(Tensor&& tensor) { using namespace cute; - auto target_stride = make_stride( - make_stride(cute::ScaledBasis, 0>{}, _0{}), - cute::ScaledBasis, 0>{}, - cute::ScaledBasis, 1>{} + auto orig_layout = tensor.layout(); + auto orig_shape = orig_layout.shape(); + auto orig_stride = orig_layout.stride(); + + auto first_dim = get<0>(orig_shape); + auto outer_part = get<0>(first_dim); + auto inner_part = get<1>(first_dim); + + auto first_stride = get<0>(orig_stride); + auto outer_stride = get<0>(first_stride); + auto inner_stride = get<1>(first_stride); + + auto target_shape = make_shape( + make_shape(outer_part, _1{}), + get<0>(inner_part), + get<1>(inner_part) ); - auto target_layout = make_layout( - make_shape(make_shape(_8{}, _1{}), _4{}, _4{}), - target_stride + auto target_stride = make_stride( + make_stride(outer_stride, _0{}), + get<0>(inner_stride), + get<1>(inner_stride) ); - return make_tensor(tensor.data(), target_layout); + return make_tensor(tensor.data(), make_layout(target_shape, target_stride)); } template< @@ -350,13 +363,10 @@ class CollectiveEpilogue< // Get thread-level partitioning across the entire workgroup tile auto thread_xe_load_c = copy_c.get_thread_slice(thread_idx); - Tensor tCgC = thread_xe_load_c.partition_S(gD); + Tensor tCgC = reshape_with_unit_insertion(thread_xe_load_c.partition_S(gD)); auto thread_xe_store_d = copy_d.get_thread_slice(thread_idx); - Tensor tCgD = thread_xe_store_d.partition_D(gD); - - auto tCgC_frag = reshape_into_smaller_fragments(tCgC); - auto tCgD_frag = reshape_into_smaller_fragments(tCgD); + Tensor tCgD = reshape_with_unit_insertion(thread_xe_store_d.partition_D(gD)); Tensor trC = make_tensor(Shape>{}); Tensor trD_compute = make_tensor(Shape>{}); @@ -414,7 +424,7 @@ class CollectiveEpilogue< cst_callbacks.begin_loop(epi_m, epi_n); if (is_C_load_needed) { - copy(copy_c, tCgC_frag(_, epi_m, epi_n), trC); + copy(copy_c, tCgC(_, epi_m, epi_n), trC); } cst_callbacks.previsit(epi_m, epi_n, 0, is_C_load_needed); @@ -432,7 +442,7 @@ class CollectiveEpilogue< for (int i = 0; i < size(trD_compute_frag); ++i) { trD_frag(i) = cutlass::NumericArrayConverter{}(trD_compute_frag(i)); } - copy(copy_d, trD, tCgD_frag(_, epi_m, epi_n)); + copy(copy_d, trD, tCgD(_, epi_m, epi_n)); } cst_callbacks.end_loop(epi_m, epi_n);
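
Taken together, the series leaves examples/00_bmg_gemm/00_bmg_gemm.cpp on the new L1-staged epilogue with explicit 2D block copy atoms for C and D. A minimal sketch of the resulting wiring, assuming the remaining aliases (TileShape, TiledMma, FusionCallBacks, ElementAccumulator, ElementOutput, LayoutC, LayoutD) are the ones already defined in that example and are not introduced by this series:

    // New copy atoms for the epilogue (PATCH 6/7). Passing void here instead lets the
    // get_block_2d_copy_{C,D} helpers from PATCH 3/7 fall back to make_block_2d_copy_C /
    // make_block_2d_copy_D and deduce a default atom from the TiledMMA.
    using GmemTiledCopyC = XE_LOAD_2D<32, 8, 16>;
    using GmemTiledCopyD = XE_STORE_2D<32, 8, 16>;

    // New dispatch policy (PATCH 1/7): selects the reworked xe_epilogue.hpp specialization
    // and the matching IntelXeL1Staged FusionCallbacks added in xe_callbacks.hpp.
    using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeL1Staged;

    using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue<
        EpilogueDispatchPolicy,
        TileShape,
        ElementAccumulator, cutlass::gemm::TagToStrideC_t<LayoutC>,
        ElementOutput,      cutlass::gemm::TagToStrideC_t<LayoutD>,
        FusionCallBacks,
        GmemTiledCopyC,     // copy atom used to load matrix C
        void, void,         // no shared-memory staging on Xe
        GmemTiledCopyD,     // copy atom used to store matrix D
        void, void>;

Because both the C load and the D store atoms are optional, the same collective can also be instantiated with void copy atoms, which is what PATCH 1/7 originally did before PATCH 6/7 switched the example to explicit XE_LOAD_2D/XE_STORE_2D atoms to avoid register spills.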