examples/00_bmg_gemm/00_bmg_gemm.cpp: 8 changes (5 additions & 3 deletions)
@@ -348,6 +348,8 @@ int main(int argc, const char** argv)
// The 2D block copy operations used for the A and B matrices
using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
using GmemTiledCopyB = XE_2D_U16x32x32_LD_V;
using GmemTiledCopyC = void; //XE_LOAD_2D<32, 8, 16>;
using GmemTiledCopyD = void; //XE_STORE_2D<32, 8, 16>;

// Workgroup-level tile
using TileShape = Shape<_256, _256, _32>;
@@ -369,7 +371,7 @@ int main(int argc, const char** argv)
// For Intel BMG, PipelineStages defines how many k-blocks ahead to prefetch from A and B.
constexpr int PipelineStages = 2;
using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16<PipelineStages>;
using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16;
using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeL1Staged;

// This is the 'default' epilogue operation (Linear Combination) which performs everything in:
// (D = alpha * (A*B) + beta * C)
@@ -392,9 +394,9 @@ int main(int argc, const char** argv)
ElementOutput,
cutlass::gemm::TagToStrideC_t<LayoutD>, // Converts CUTLASS 2.x to CUTLASS 3.x representation
FusionCallBacks,
XE_2D_U32x8x16_LD_N, // The copy atom used to load matrix C
GmemTiledCopyC, // The copy atom used to load matrix C
void, void,
XE_2D_U32x8x16_ST_N, // The copy atom used to store matrix D
GmemTiledCopyD, // The copy atom used to store matrix D
void, void>;

// GEMM Mainloop - iteration over blocks in K dimension
@@ -71,6 +71,7 @@ class CollectiveEpilogue {
#include "sm100_epilogue_array_tma_warpspecialized.hpp"
#if defined (SYCL_INTEL_TARGET)
#include "xe_epilogue.hpp"
#include "xe_epilogue_legacy.hpp"
#include "xe_array_epilogue.hpp"
#endif
//
include/cutlass/epilogue/collective/xe_epilogue.hpp: 241 changes (114 additions & 127 deletions)
@@ -68,7 +68,7 @@ template <
class CopyOpR2S_
>
class CollectiveEpilogue<
IntelXeXMX16,
IntelXeL1Staged,
CtaTileMNK_,
ElementC_,
StrideC_,
@@ -86,7 +86,7 @@ class CollectiveEpilogue<
//
// Type Aliases
//
using DispatchPolicy = IntelXeXMX16;
using DispatchPolicy = IntelXeL1Staged;
using CtaTileMNK = CtaTileMNK_;
using FusionCallbacks = FusionCallbacks_;
using ElementC = ElementC_;
@@ -102,9 +102,6 @@ class CollectiveEpilogue<
using CopyOpR2S = CopyOpR2S_;

using ThreadEpilogueOp = typename fusion::FusionCallbacksTraits<FusionCallbacks>::Operation;
using GmemTiledCopyC = conditional_t<cute::is_void_v<CopyOpG2R>, XE_2D_U32x8x16_LD_N, CopyOpG2R>;
using GmemTiledCopyD = cute::conditional_t<not cute::is_void_v<ElementD> && not cute::is_void_v<CopyOpR2G>,
CopyOpR2G, XE_2D_U32x8x16_ST_N>;
using ElementOutput = ElementD;
using ElementCompute = ElementAccumulator;

@@ -119,19 +116,10 @@ class CollectiveEpilogue<
static_assert(std::is_same_v<SmemLayoutAtomC, void>, "Copy operation to shared memory is not supported");
static_assert(std::is_same_v<SmemLayoutAtomD, void>, "Copy operation to shared memory is not supported");

using CopyThreadShape = Shape<_1, Int<SubgroupSize>>;

using Trait_C = Copy_Traits<GmemTiledCopyC, StrideC>;
using val_layout_load_C = decltype(make_layout(shape_div(typename Trait_C::BlockShape{}, CopyThreadShape{})));
using XE_Copy_C = decltype(make_tiled_copy(Copy_Atom<Trait_C, ElementC>{}, Layout<CopyThreadShape>{}, val_layout_load_C{}));

using Trait_D = Copy_Traits<GmemTiledCopyD, StrideD>;
using val_layout_store_D = decltype(make_layout(shape_div(typename Trait_D::BlockShape{}, CopyThreadShape{})));
using XE_Copy_D = decltype(make_tiled_copy(Copy_Atom<Trait_D, ElementD>{}, Layout<CopyThreadShape>{}, val_layout_store_D{}));

//remember this PR https://github.com/intel/sycl-tla/pull/565/files
private:
constexpr static bool is_source_supported = not cute::is_void_v<ElementC> && not cute::is_void_v<CopyOpG2R>;
@anamikac-intel (Author): @jiyang1011 - The validation logic from PR #565 that sets is_source_supported to false when CopyOpG2R is void needs updating. With this PR's automatic op selection, both CopyOpG2R and CopyOpR2G can now legitimately be void, since make_block_2d_copy_* automatically selects appropriate operations.
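
Concretely, the example above now exercises the automatic path; condensed from 00_bmg_gemm.cpp in this PR (the commented-out ops are the ones the example itself suggests):

```cpp
// Let the epilogue pick the 2D block copy ops automatically (this PR's path):
using GmemTiledCopyC = void;   // epilogue falls back to make_block_2d_copy_C(...)
using GmemTiledCopyD = void;   // epilogue falls back to make_block_2d_copy_D(...)

// Or pin specific ops; these are forwarded to make_block_2d_copy_CD(...):
// using GmemTiledCopyC = XE_LOAD_2D<32, 8, 16>;
// using GmemTiledCopyD = XE_STORE_2D<32, 8, 16>;
```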

@jiyang1011 (Oct 23, 2025): Could we set a default copy trait like XeCopyAuto or something else which will also call make_block_2d_copy_*?
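
A minimal sketch of what that suggestion could look like; XeCopyAuto and select_copy_c are hypothetical names (not part of this PR), while make_block_2d_copy_C / make_block_2d_copy_CD are the helpers the new epilogue already calls:

```cpp
// Hypothetical marker type: "choose the 2D block copy op automatically".
struct XeCopyAuto {};

// Sketch only: dispatch on the tag instead of on void (assumes the headers
// already included by xe_epilogue.hpp, plus <type_traits>).
template <class CopyOpG2R, class TiledMma, class TensorC>
CUTLASS_DEVICE auto select_copy_c(TiledMma const& tiled_mma, TensorC const& mC) {
  if constexpr (std::is_same_v<CopyOpG2R, XeCopyAuto>) {
    return make_block_2d_copy_C(tiled_mma, mC);               // automatic selection
  } else {
    return make_block_2d_copy_CD(CopyOpG2R{}, tiled_mma, mC); // user-specified op
  }
}
```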

constexpr static bool is_destination_supported = not cute::is_void_v<ElementD> && not cute::is_void_v<CopyOpR2G>;
constexpr static bool is_source_supported = not cute::is_void_v<ElementC>;
constexpr static bool is_destination_supported = not cute::is_void_v<ElementD>;

constexpr static bool is_m_major_C = detail::is_m_major<StrideC>();
constexpr static bool is_m_major_D = detail::is_m_major<StrideD>();
@@ -166,8 +154,11 @@ class CollectiveEpilogue<
// Device side epilogue params
struct Params {
typename FusionCallbacks::Params thread{};
XE_Copy_C xe_load_c;
XE_Copy_D xe_store_d;
ElementC const* ptr_C;
ElementD* ptr_D;
int M, N, K, L;
StrideC dC;
StrideD dD;
};

//
@@ -184,22 +175,13 @@ class CollectiveEpilogue<
auto problem_shape_MNKL = append<4>(problem_shape, 1);
auto [M, N, K, L] = problem_shape_MNKL;

XE_Copy_C xe_load_c = {};
if constexpr (is_source_supported) {
auto mC = make_tensor(make_gmem_ptr(args.ptr_C), make_layout(make_shape(M, N, L), args.dC));
xe_load_c = {xe_load_c.with(mC)};
}

XE_Copy_D xe_store_d = {};
if constexpr (is_destination_supported) {
auto mD = make_tensor(make_gmem_ptr(args.ptr_D), make_layout(make_shape(M, N, L), args.dD));
xe_store_d = {xe_store_d.with(mD)};
}

return {
FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, workspace),
xe_load_c,
xe_store_d,
args.ptr_C,
args.ptr_D,
M, N, K, L,
args.dC,
args.dD
};
}

@@ -286,7 +268,6 @@ class CollectiveEpilogue<
TiledMma tiled_mma,
int thread_idx) {

(void) tiled_mma;
using namespace cute;

static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]");
@@ -297,143 +278,149 @@ class CollectiveEpilogue<
static constexpr auto BLK_M = get<0>(CtaTileMNK{});
static constexpr auto BLK_N = get<1>(CtaTileMNK{});
static constexpr auto BLK_K = get<2>(CtaTileMNK{});
// static_assert(is_same_v<typename TiledMma::ThrLayoutVMNK, int>, "assertation fail");
static constexpr auto ATOM_M = get<1>(typename TiledMma::ThrLayoutVMNK{}.shape());
static constexpr auto ATOM_N = get<2>(typename TiledMma::ThrLayoutVMNK{}.shape());
static constexpr auto ATOM_K = get<3>(typename TiledMma::ThrLayoutVMNK{}.shape());

static_assert(
BLK_M % ATOM_M == 0 &&
BLK_N % ATOM_N == 0 &&
BLK_K % ATOM_K == 0,
"expected CTATileMNK to be evenly divided by TiledMma::ThrLayoutVMNK");
static constexpr auto SG_M = BLK_M / ATOM_M;
static constexpr auto SG_N = BLK_N / ATOM_N;
static constexpr auto SG_K = BLK_K / ATOM_K;
using SubgroupTileShape = Shape<decltype(SG_M), decltype(SG_N), decltype(SG_K)>;

static constexpr int FragsM = get<0>(SubgroupTileShape{}) / get<0>(MmaAtomShape()); // A frags per sub_group
static constexpr int FragsN = get<1>(SubgroupTileShape{}) / get<1>(MmaAtomShape()); // B frags per sub_group

static constexpr int FragmentSize = (get<0>(MmaAtomShape()) * get<1>(MmaAtomShape())) / SubgroupSize;


// Indexing variables
auto [M, N, K, L] = problem_shape_mnkl;
auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl;
auto m_sg = get_sub_group_id() / ATOM_N;
auto n_sg = get_sub_group_id() % ATOM_N;

auto mn_shape = shape(typename decltype(params.xe_store_d)::Tiler_MN{});

auto sg_local_m_coord = get_sub_group_id() / ATOM_N;
auto sg_local_n_coord = get_sub_group_id() % ATOM_N;

auto sg_m_coord = m_coord * ATOM_M + sg_local_m_coord;
auto sg_n_coord = n_coord * ATOM_N + sg_local_n_coord;
auto sg_coord = make_coord(sg_m_coord, sg_n_coord, k_coord, l_coord);

// Workgroup coordinates (no subgroup indexing needed)
auto wg_coord = make_coord(m_coord, n_coord, k_coord, l_coord);
auto batch_idx = get<3>(wg_coord);

bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed();
auto mC = make_tensor(make_gmem_ptr(params.ptr_C), make_layout(make_shape(params.M, params.N, params.L), params.dC));
auto mD = make_tensor(make_gmem_ptr(params.ptr_D), make_layout(make_shape(params.M, params.N, params.L), params.dD));

auto copy_c = [&]() {
if constexpr (!std::is_void_v<CopyOpG2R>) {
return make_block_2d_copy_CD(CopyOpG2R{}, tiled_mma, mC(_,_,batch_idx));
} else {
return make_block_2d_copy_C(tiled_mma, mC(_,_,batch_idx));
}
}();
auto copy_d = [&]() {
if constexpr (!std::is_void_v<CopyOpR2G>) {
return make_block_2d_copy_CD(CopyOpR2G{}, tiled_mma, mD(_,_,batch_idx));
} else {
return make_block_2d_copy_D(tiled_mma, mD(_,_,batch_idx));
}
}();


// Represent the full output tensor
Tensor mD_mnl = cute::get_xe_tensor(make_shape(M,N,L));

// Tile the output tensor per WG and select the tile for current WG
Tensor g_wg_D = local_tile(mD_mnl, take<0,2>(CtaTileMNK{}), make_coord(m_coord,n_coord,l_coord)); // (BLK_M,BLK_N)

// Tile the output tensor per SG and select tile for the current SG
Tensor gD = local_tile(g_wg_D, take<0,2>(SubgroupTileShape{}), make_coord(m_sg,n_sg)); // (SG_M,SG_N)
// Tile the output tensor for the current workgroup
Tensor gD = local_tile(mD_mnl, take<0,2>(CtaTileMNK{}), remove<2>(wg_coord)); // (BLK_M,BLK_N) // change made

auto thread_xe_load_c = params.xe_load_c.get_thread_slice(thread_idx);
// Get thread-level partitioning across the entire workgroup tile
auto thread_xe_load_c = copy_c.get_thread_slice(thread_idx);
Tensor tCgC = thread_xe_load_c.partition_S(gD);
@anamikac-intel (Author, Oct 23, 2025):

@petercad - I was trying to fix the register spills; reducing copy operations is not helping much. The issue seems to come from the tiling. In the legacy code, after partitioning, tCgC and tCgD were:
tCgC : ArithTuple(0,0,0) o ((_8,_1),_4,_4):((_1@0,_0),_8@0,_16@1)
tCgD : ArithTuple(0,0,0) o ((_8,_1),_4,_4):((_1@0,_0),_8@0,_16@1)

So we have 8 fragments of size 4 x 4

whereas in new code we have 128 fragments of size 1 x 1:

tCgC: ArithTuple(0,0,0) o ((_8,(_4,_4)),_1,_1):((_1@0,(_8@0,_16@1)),_0,_0)
tCgD: ArithTuple(0,0,0) o ((_8,(_4,_4)),_1,_1):((_1@0,(_8@0,_16@1)),_0,_0)

I tried tiling further with SubgroupTileShape{} (_32,_64,_32), but got the same result:
g_wg_D: ArithTuple(0,0,0) o (_256,_256):(_1@0,_1@1)
gD: ArithTuple(0,0,0) o (_32,_64):(_1@0,_1@1)
tCgC: ArithTuple(0,0,0) o ((_8,(_4,_4)),_1,_1):((_1@0,(_8@0,_16@1)),_0,_0)
tCgD: ArithTuple(0,0,0) o ((_8,(_4,_4)),_1,_1):((_1@0,(_8@0,_16@1)),_0,_0)

This seems to be the actual issue: when I reshaped tCgC/tCgD to ArithTuple(0,0,0) o ((_8,_1),_4,_4):((_1@0,_0),_8@0,_16@1) (8 fragments of 4x4), the perf drop is fixed. But re-layouting tCgC/tCgD might not be the best option, so can you please check?


@sanchitintel (Oct 23, 2025):

> we have 8 fragments of size 4 x 4

No, the per-thread D or C fragment size is 128 elements in both cases, but layout is different.
The WG tile size is (256, 256, 32).
There are 32 subgroups, each with 16 threads.
Each SG fragment for C or D is (32, 64) spatially since the subgroup layout is 8x4 in the example.
Each C or D thread fragment is sized 128 elements.

Both ((_8,_1),_4,_4) and ((_8,(_4,_4)),_1,_1) have 128 elements.
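
A standalone check of the counts above, assuming only CuTe (the two layouts are the per-thread fragment shapes from the dumps in this thread):

```cpp
#include <cute/tensor.hpp>
#include <iostream>

int main() {
  using namespace cute;

  // Legacy per-thread fragment shape: ((_8,_1),_4,_4)
  auto legacy = make_layout(Shape<Shape<_8,_1>, _4, _4>{});
  // New per-thread fragment shape: ((_8,(_4,_4)),_1,_1)
  auto fresh  = make_layout(Shape<Shape<_8, Shape<_4,_4>>, _1, _1>{});

  std::cout << int(size(legacy)) << "\n";            // 128
  std::cout << int(size(fresh))  << "\n";            // 128
  std::cout << (256 * 256) / (8 * 4 * 16) << "\n";   // 128: WG tile / (8x4 SGs * 16 threads)
  return 0;
}
```

Both shapes cover the same 128 elements per thread; only the mode nesting differs.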


@sanchitintel (Oct 24, 2025):

Hi @petercad, the recent commits contain a lot more changes than just the thread-fragment layout (which seems to be equivalent to the previous one). That suggests there is a lurking factor that fixed the performance issues observed earlier in this PR, and that the thread-fragment layout of the new C/D copy atoms isn't problematic.

Thanks!

@anamikac-intel (Author):

Yes, this code is close to the legacy one, except that I am using the new copy atoms with the reshaped layout... but the only concern is that it only works with ops that have 16 width × 8 height #573 (comment)

@petercad (Oct 24, 2025):

@sanchitintel -- will look at it in more detail shortly. The underlying cause for earlier regressions seems to be twofold:

  • Code scheduling issues in IGC. It seems it is not moving loads/stores around sufficiently to reduce register pressure.
  • For the C loads, make_block_2d_copy_C will try to make the blocks as large as possible (because it's operating on the assumption that you're loading all of C at once), but that brings additional register pressure.

The second point is not enough to explain the spills (there is plenty of register space even if you do load huge chunks of C), but it aggravates the first point.
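
For scale (a back-of-the-envelope figure, not a measurement from this PR): with the 256x256 workgroup tile and 32 subgroups of 16 threads, each thread holds 128 C values, i.e. 128 x 4 B = 512 B of fp32 data, so even loading large chunks of C does not come close to exhausting the register file on its own; the spills appear when scheduling keeps those values live alongside the accumulators for too long.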


auto thread_xe_store_d = params.xe_store_d.get_thread_slice(thread_idx);
auto thread_xe_store_d = copy_d.get_thread_slice(thread_idx);
Tensor tCgD = thread_xe_store_d.partition_D(gD);

Tensor trC = make_tensor<typename TiledMma::ValTypeC>(Shape<Int<FragmentSize>>{});
Tensor trD_compute = make_tensor<ElementCompute>(Shape<Int<FragmentSize>>{});
// Create tensors sized for workgroup-level operation
Tensor trC = make_tensor<typename TiledMma::ValTypeC>(tCgC.shape());
Tensor trD_compute = make_tensor<ElementCompute>(tCgD.shape());
Tensor trD = make_tensor<ElementOutput>(tCgD.shape());

ThrCopy thread_g2r = copy_c.get_slice(thread_idx);

// Because Sm90 uses shared memory, they are not tied to using the same accumulator values
// for MMA and Epilogue. But because we are operating directly in the accumulators, we need to be
// sure that we are operating on the same values.
ThrCopy thread_g2r = params.xe_load_c.get_slice(thread_idx);
auto mn_shape = shape(typename decltype(copy_d)::Tiler_MN{});

// OOB predication for tile quantization "residue"
// Absolute coordinate tensors (dynamic)
Tensor mD_crd = make_identity_tensor(make_shape(M,N)); // (M,N)
Tensor cD = local_tile(mD_crd, take<0,2>(SubgroupTileShape{}), make_coord(sg_m_coord, sg_n_coord));
Tensor cD_mn = local_tile(mD_crd, take<0,2>(CtaTileMNK{}), make_coord(m_coord, n_coord)); // (CTA_M,CTA_N)
Tensor tRS_cD_mn = thread_g2r.partition_S(flat_divide(cD_mn, mn_shape)); // (G2R,G2R_M,G2R_N,EPI_M,EPI_N)
Tensor mD_crd = make_identity_tensor(make_shape(M,N));
Tensor cD = local_tile(mD_crd, take<0,2>(CtaTileMNK{}), make_coord(m_coord, n_coord));
Tensor tRS_cD = thread_g2r.partition_S(flat_divide(cD, mn_shape));

Tensor tRS_cD = make_coord_tensor(tRS_cD_mn.layout()); // (G2R,G2R_M,G2R_N,EPI_M,EPI_N)
Tensor tRS_cD_coord = make_coord_tensor(tRS_cD.layout());

// Get the fusion callbacks
// Arguments passed here relate to sub-group tiles, rather than CTA (work-group) tiles
// Get fusion callbacks at workgroup level
constexpr bool RefSrc = true;
auto residue_mn = make_coord(M, N); //TODO(Codeplay): this is not correct
auto residue_mn = make_coord(M, N);
auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{
problem_shape_mnkl,
SubgroupTileShape{},
sg_coord,
CtaTileMNK{}, // Use workgroup tile shape
wg_coord, // Use workgroup coordinates
tiled_mma,
mn_shape,
params.xe_store_d,
copy_d,
cD,
residue_mn,
tRS_cD,
tRS_cD_coord,
residue_mn,
trC,
thread_idx,
};
auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks<RefSrc>(cst_args);
auto synchronize = [&] () {};

cst_callbacks.begin();
// Load C tile if needed (distributed across all threads in workgroup)
if (is_C_load_needed) {
copy(copy_c, tCgC, trC);
}

auto acc_frag = recast<Array<ElementCompute, FragmentSize>>(accumulators);
auto trD_compute_frag = recast<Array<ElementCompute, FragmentSize>>(trD_compute);

Tensor trD = make_tensor<ElementOutput>(Shape<Int<FragmentSize>>{});
auto trD_frag = recast<Array<ElementOutput, FragmentSize>>(trD);

constexpr int ValuesLoaded =
FragsM * FragsN * FragmentSize * SubgroupSize * ATOM_M * ATOM_N * ATOM_K;
constexpr int MN = get<0>(CtaTileMNK{}) * get<1>(CtaTileMNK{});
static_assert(ValuesLoaded == MN, "the total elements loaded by all threads should be the same as MxN" );
// Single previsit for entire workgroup tile
cst_callbacks.previsit(0, 0, 0, is_C_load_needed);

auto synchronize = [&] () {};
static constexpr int FragmentSize = get<0>(MmaAtomShape()) * get<1>(MmaAtomShape());
constexpr int num_fragments = size(accumulators) / FragmentSize;

CUTLASS_PRAGMA_UNROLL
for (int epi_n = 0; epi_n < FragsN; epi_n++) {
for (int epi_v = 0; epi_v < num_fragments; ++epi_v) {
// Extract fragment
Array<ElementAccumulator, FragmentSize> frg_acc;
CUTLASS_PRAGMA_UNROLL
for (int epi_m = 0; epi_m < FragsM; epi_m++) {
cst_callbacks.begin_loop(epi_m, epi_n);

if (is_C_load_needed) {
copy(params.xe_load_c, tCgC(_, epi_m, epi_n), trC);
}

cst_callbacks.previsit(epi_m, epi_n, 0, is_C_load_needed);
for (int f = 0; f < FragmentSize; ++f) {
frg_acc[f] = accumulators(epi_v * FragmentSize + f);
}

// Process fragment
auto result_frg = cst_callbacks.visit(frg_acc, epi_v, 0, 0);

// Store results
CUTLASS_PRAGMA_UNROLL
for (int f = 0; f < FragmentSize; ++f) {
trD_compute(epi_v * FragmentSize + f) = result_frg[f];
}
}

auto acc_frag_mn = acc_frag(_, epi_m, epi_n);
cst_callbacks.reduce(nullptr, synchronize, 0, 0, true, trD_compute);

if constexpr (is_destination_supported) {
// Convert fragments using NumericArrayConverter
constexpr int num_fragments_trD_compute = size(trD_compute) / FragmentSize;
using Converter = cutlass::NumericArrayConverter<ElementOutput, ElementCompute, FragmentSize>;
Converter converter{};

CUTLASS_PRAGMA_UNROLL
for (int epi_v = 0; epi_v < num_fragments_trD_compute; ++epi_v) {
// Extract compute fragment
Array<ElementCompute, FragmentSize> trD_compute_frag;
CUTLASS_PRAGMA_UNROLL
for (int epi_v = 0; epi_v < size<0>(trD_compute_frag); ++epi_v) {
trD_compute_frag(epi_v) = cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n);
for (int f = 0; f < FragmentSize; ++f) {
trD_compute_frag[f] = trD_compute(epi_v * FragmentSize + f);
}
cst_callbacks.reduce(nullptr, synchronize, epi_m, epi_n, (epi_m == FragsM - 1 && epi_n == FragsN - 1), trD_compute_frag);

if constexpr (is_destination_supported) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(trD_compute_frag); ++i) {
trD_frag(i) = cutlass::NumericArrayConverter<ElementOutput, ElementCompute, FragmentSize>{}(trD_compute_frag(i));
}
copy(params.xe_store_d, trD, tCgD(_, epi_m, epi_n));
}
// Convert fragment
auto trD_frag = converter(trD_compute_frag);

cst_callbacks.end_loop(epi_m, epi_n);
// Store converted fragment
CUTLASS_PRAGMA_UNROLL
for (int f = 0; f < FragmentSize; ++f) {
trD(epi_v * FragmentSize + f) = trD_frag[f];

}
}
}

copy(copy_d, trD, tCgD);
}

cst_callbacks.end();
}
cst_callbacks.end();

}

private:
Params const& params;
@@ -447,4 +434,4 @@ class CollectiveEpilogue<
} // namespace epilogue
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////