Merged

Commits (23 total; changes shown from 7 commits)
83f1de8
Test commit
Sep 29, 2025
0b184f0
Enable new mma and copy atoms
Sep 29, 2025
ef1bafa
adding legacy code back for collectivemma and gemmuniversal
Sep 30, 2025
f210ba3
delete unwanted file
Sep 30, 2025
5f5a8b7
Changes added based on feedback
Oct 1, 2025
c55ac28
Remove xe_gemm_legacy as its not longer used
Oct 1, 2025
946b46c
Changes added based on feedback
Oct 3, 2025
c97f011
Applied review comments
Oct 4, 2025
9691e60
Add compile-time checks to enforce new XE copy atoms in block 2D func…
Oct 6, 2025
93b076a
Modified static assert message
Oct 6, 2025
a6f068c
Modified static assert message
Oct 6, 2025
fcbfecf
Merge branch 'intel:main' into anamikac/add-newatoms
anamikac-intel Oct 6, 2025
e1e64f7
Move legacy example to legacy folder, pass 2D strides to make_block_2…
Oct 8, 2025
ea67069
Applied reviwer comment
Oct 10, 2025
e9878b9
This is an empty commit
Oct 10, 2025
fbb7bb5
Preventing exceptions on older IGC versions
anamikac-intel Oct 10, 2025
4fb70c0
Remove unwanted returns from device-side params
Oct 12, 2025
4fd4376
Modify compile-time checks to enforce new XE copy atoms in block 2D f…
Oct 13, 2025
ca503bf
Applied review comments
Oct 17, 2025
4eb3bf3
Add batch_idx to global tensor passed to make_block_2d_copy_* and Blo…
Oct 19, 2025
07aa4c8
Merge branch 'intel:main' into anamikac/add-newatoms
anamikac-intel Oct 20, 2025
800480a
Added comments on why batch indexing used for make_block_2d_copy_*
Oct 20, 2025
018ffb8
Merge branch 'intel:main' into anamikac/add-newatoms
anamikac-intel Oct 22, 2025
29 changes: 15 additions & 14 deletions examples/00_bmg_gemm/00_bmg_gemm.cpp
@@ -345,30 +345,24 @@ int main(int argc, const char** argv)
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;

// The 2D block copy operations used for the A and B matrices
using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
using GmemTiledCopyB = XE_2D_U16x32x32_LD_V;

// Workgroup-level tile
using TileShape = Shape<_256, _256, _32>;

// A TiledMMA struct defines a tiling of an MMA atom over M, N and K, combining both additional
// hardware (sub-groups for Intel BMG) and iterations by each sub-group.
//
// The TiledMMAHelper struct defines a specific TiledMMA for a given MMA atom
// (XE_8x16x16_F32BF16BF16F32_TT), TileShape (<256, 256, 32>) and sub-group layout (8x4x1). The
// TiledMMA constructed using TiledMMAHelper has the property that each sub-group operates on a
// The TiledMMAHelper struct defines a specific TiledMMA for a given MMA atom, TileShape and sub-group layout. This example uses
// the XE_DPAS_TT<8, float, cute::bfloat16_t> atom (an 8x16x16 DPAS operation with float32 accumulation and bfloat16 inputs), a TileShape of <256, 256, 32> and an 8x4x1 sub-group layout.
// The TiledMMA constructed using TiledMMAHelper has the property that each sub-group operates on a
// single contiguous chunk of the work-group TileShape. For this configuration, this implies that
// each sub-group operates on a contiguous 32x64x32 chunk (4x4x2 iterations). See
// 0t_mma_atom.md#TiledMMAs for more info. Sub-groups are arranged row-major (stride 4,1,0) for
// performance reasons.
using TiledMma = // M=8,N=16,K=16, D=f32,A=bf16,B=bf16,C=f32
typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
using TiledMma = typename TiledMMAHelper<MMA_Atom<XE_DPAS_TT<8, float, cute::bfloat16_t>>, Layout<TileShape>, Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
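// A minimal illustrative sketch (not part of the example's configuration) of the sub-group tile
// arithmetic described above: with TileShape <256, 256, 32> and an 8x4x1 sub-group layout, each
// sub-group covers M = 256/8 = 32 and N = 256/4 = 64 plus the full K = 32 slice, i.e. a 32x64x32
// chunk executed as 4x4x2 iterations of the 8x16x16 DPAS atom.
// static_assert(256 / 8 == 32 && 256 / 4 == 64, "per-sub-group MxN chunk of the work-group tile");
// static_assert(32 / 8 == 4 && 64 / 16 == 4 && 32 / 16 == 2, "4x4x2 atom iterations per sub-group");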

// For Intel BMG, PipelineStages defines how many k-blocks ahead to prefetch from A and B.
constexpr int PipelineStages = 2;
using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16<PipelineStages>;
using GEMMDispatchPolicy = cutlass::gemm::MainloopXeL1Staged<PipelineStages>;
using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16;
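// A minimal sketch (mirroring the mainloop in xe_mma.hpp further down) of the schedule implied by
// PipelineStages: prefetches run PipelineStages k-blocks ahead of the k-block being copied and
// multiplied; names such as k_tile_count and prefetch_k are taken from that mainloop.
//   int prefetch_k = 0;
//   for (; prefetch_k < PipelineStages; ++prefetch_k) { /* prefetch A/B k-block prefetch_k */ }
//   for (int k_tile = 0; k_tile < k_tile_count; ++k_tile, ++prefetch_k) {
//     /* copy A/B k-block k_tile into registers */
//     if (prefetch_k < k_tile_count) { /* prefetch A/B k-block prefetch_k */ }
//     /* mma on k-block k_tile */
//   }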

// This is the 'default' epilogue operation (Linear Combination) which performs everything in:
@@ -398,6 +392,13 @@ int main(int argc, const char** argv)
void, void>;

// GEMM Mainloop - iteration over blocks in K dimension
//
// Copy operations for A and B matrices:
// - Use 'void' (as shown below) to automatically select new 2D block copy operations
// - To use legacy copy operations, replace 'void' with specific copy atoms, e.g.:
// using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
// using GmemTiledCopyB = XE_2D_U16x32x32_LD_V;
// Then pass GmemTiledCopyA in place of the first 'void' on the A line below and GmemTiledCopyB in place of the first 'void' on the B line (see the sketch after this declaration)
using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma<
GEMMDispatchPolicy,
TileShape,
@@ -406,8 +407,8 @@ int main(int argc, const char** argv)
ElementInputB,
cutlass::gemm::TagToStrideB_t<LayoutB>, // Converts CUTLASS 2.x to CUTLASS 3.x representation
TiledMma,
GmemTiledCopyA, void, void, cute::identity, // A
GmemTiledCopyB, void, void, cute::identity // B
void, void, void, cute::identity, // A
void, void, void, cute::identity // B
>;
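// A minimal sketch of the legacy configuration described in the comment above. It assumes the
// legacy XE_2D_* copy atoms are still available in this build and reuses the element/stride
// aliases defined earlier in this example; only the GmemTiledCopyA and GmemTiledCopyB slots change.
// using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;  // legacy 2D block load for A
// using GmemTiledCopyB = XE_2D_U16x32x32_LD_V;  // legacy 2D block load for B
// using LegacyCollectiveMainloop = cutlass::gemm::collective::CollectiveMma<
//     GEMMDispatchPolicy,
//     TileShape,
//     ElementInputA,
//     cutlass::gemm::TagToStrideA_t<LayoutA>,   // Converts CUTLASS 2.x to CUTLASS 3.x representation
//     ElementInputB,
//     cutlass::gemm::TagToStrideB_t<LayoutB>,   // Converts CUTLASS 2.x to CUTLASS 3.x representation
//     TiledMma,
//     GmemTiledCopyA, void, void, cute::identity,  // A: legacy copy atom in the GmemTiledCopyA slot
//     GmemTiledCopyB, void, void, cute::identity   // B: legacy copy atom in the GmemTiledCopyB slot
// >;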

// Define the whole kernel (mainloop and epilogue)
1 change: 1 addition & 0 deletions include/cutlass/gemm/collective/collective_mma.hpp
@@ -77,6 +77,7 @@
#endif // !defined(__CUDACC_RTC__)

#if defined(SYCL_INTEL_TARGET)
#include "cutlass/gemm/collective/xe_mma_legacy.hpp"
#include "cutlass/gemm/collective/xe_mma.hpp"
#include "cutlass/gemm/collective/xe_array_mma.hpp"
#include "cutlass/gemm/collective/xe_array_mma_fp8.hpp"
139 changes: 87 additions & 52 deletions include/cutlass/gemm/collective/xe_mma.hpp
@@ -47,13 +47,13 @@ using namespace cute;
template <int Stages, class Schedule, class TileShape_, class ElementA_, class StrideA_, class ElementB_, class StrideB_,
class TiledMma_, class GmemTiledCopyA_, class SmemLayoutAtomA_, class SmemCopyAtomA_, class TransformA_,
class GmemTiledCopyB_, class SmemLayoutAtomB_, class SmemCopyAtomB_, class TransformB_>
struct CollectiveMma<MainloopIntelXeXMX16<Stages, Schedule>, TileShape_, ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_,
struct CollectiveMma<MainloopXeL1Staged<Stages, Schedule>, TileShape_, ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_,
GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_, GmemTiledCopyB_, SmemLayoutAtomB_,
SmemCopyAtomB_, TransformB_> {
//
// Type Aliases
//
using DispatchPolicy = MainloopIntelXeXMX16<Stages, Schedule>;
using DispatchPolicy = MainloopXeL1Staged<Stages, Schedule>;
using WorkgroupTileShape = TileShape_;
using ElementA = ElementA_;
using StrideA = StrideA_;
@@ -71,7 +71,7 @@ struct CollectiveMma<MainloopIntelXeXMX16<Stages, Schedule>, TileShape_, Element
using TransformB = TransformB_;
using ArchTag = typename DispatchPolicy::ArchTag;

static_assert(platform::is_same<ElementA, ElementB>::value, "MainloopIntelXeXMX16 requires that A and B have same type.");
static_assert(platform::is_same<ElementA, ElementB>::value, "MainloopXeL1Staged requires that A and B have same type.");
static_assert(std::is_same_v<TransformA, cute::identity>, "Transformation for A is not currently supported on Intel PVC");
static_assert(std::is_same_v<TransformB, cute::identity>, "Transformation for B is not currently supported on Intel PVC");

@@ -100,9 +100,6 @@ struct CollectiveMma<MainloopIntelXeXMX16<Stages, Schedule>, TileShape_, Element
static constexpr auto Num_SGs = ATOM_N * ATOM_M * ATOM_K;
static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{});

using Copy_A = typename Copy_Traits<GmemTiledCopyA, StrideA>::template DefaultTiledCopy<ElementA>;
using Copy_B = typename Copy_Traits<GmemTiledCopyB, StrideB>::template DefaultTiledCopy<ElementB>;

// Host side kernel arguments
struct Arguments {
ElementA const* ptr_A;
@@ -112,8 +109,11 @@ struct CollectiveMma<MainloopIntelXeXMX16<Stages, Schedule>, TileShape_, Element
};

struct Params {
Copy_A tiled_copy_a;
Copy_B tiled_copy_b;
ElementA const* ptr_A;
StrideA dA;
ElementB const* ptr_B;
StrideB dB;
int M, N, K, L;
};

//
@@ -129,12 +129,11 @@ struct CollectiveMma<MainloopIntelXeXMX16<Stages, Schedule>, TileShape_, Element

auto [M,N,K,L] = problem_shape;

auto mA_mkl = make_tensor(make_gmem_ptr(args.ptr_A), make_layout(make_shape(M, K, L), args.dA));
auto mB_nkl = make_tensor(make_gmem_ptr(args.ptr_B), make_layout(make_shape(N, K, L), args.dB));
Copy_A tiled_copy_a{Copy_A{}.with(mA_mkl)};
Copy_B tiled_copy_b{Copy_B{}.with(mB_nkl)};

return Params{tiled_copy_a, tiled_copy_b};
return Params{args.ptr_A,
args.dA,
args.ptr_B,
args.dB,
M, N, K, L};
}

template<class ProblemShape>
@@ -177,59 +176,94 @@ struct CollectiveMma<MainloopIntelXeXMX16<Stages, Schedule>, TileShape_, Element
static_assert(is_rmem<FrgTensorD>::value, "D tensor must be rmem resident.");
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");

auto thr_copy_A = mainloop.tiled_copy_a.get_slice(thread_idx);
auto thr_copy_B = mainloop.tiled_copy_b.get_slice(thread_idx);
auto copy_a = [&]() {
if constexpr (!std::is_void_v<GmemTiledCopyA>) {
// User provided copy operation - use full stride
auto mA_mkl = make_tensor(make_gmem_ptr(mainloop.ptr_A),
make_layout(make_shape(mainloop.M, mainloop.K, mainloop.L), mainloop.dA));
using Copy_A = typename Copy_Traits<GmemTiledCopyA, StrideA>::template DefaultTiledCopy<ElementA>;
return Copy_A{}.with(mA_mkl);
} else {
// Use new 2D copy operations with 2D stride
auto mA_mkl = make_tensor(make_gmem_ptr(mainloop.ptr_A),
make_layout(make_shape(mainloop.M, mainloop.K, mainloop.L), cute::take<0,2>(mainloop.dA)));
return make_block_2d_copy_A(TiledMma{}, mA_mkl);
}
}();

auto copy_b = [&]() {
if constexpr (!std::is_void_v<GmemTiledCopyB>) {
// User provided copy operation - use full stride
auto mB_nkl = make_tensor(make_gmem_ptr(mainloop.ptr_B),
make_layout(make_shape(mainloop.N, mainloop.K, mainloop.L), mainloop.dB));
using Copy_B = typename Copy_Traits<GmemTiledCopyB, StrideB>::template DefaultTiledCopy<ElementB>;
return Copy_B{}.with(mB_nkl);
} else {
// Use new 2D copy operations with 2D stride
auto mB_nkl = make_tensor(make_gmem_ptr(mainloop.ptr_B),
make_layout(make_shape(mainloop.N, mainloop.K, mainloop.L), cute::take<0,2>(mainloop.dB)));
return make_block_2d_copy_B(TiledMma{}, mB_nkl);
}
}();

auto thr_copy_a = copy_a.get_slice(thread_idx);
auto thr_copy_b = copy_b.get_slice(thread_idx);

// Instantiate the MMA object and get thread slice
TiledMma tiled_mma;
// TODO(Codeplay): see if we can make this nicer
// To make all work items in a subgroup have the same global tensors pass in the index of work item 0 in each subgroup
auto sg = compat::get_nd_item<1>().get_sub_group();
auto first_thread_in_sg_idx = sg.get_group_linear_id() * DispatchPolicy::SubgroupSize;
auto thr_mma = tiled_mma.get_slice(first_thread_in_sg_idx);

// Partition global counting tensors for MMA
Tensor tCgA = thr_mma.partition_A(gA);
Tensor tCgB = thr_mma.partition_B(gB);

Tensor tCrA = make_tensor<ElementA>(make_fragment_layout(mainloop.tiled_copy_a, tCgA(_,_,_,0).shape()));
Tensor tCrB = make_tensor<ElementB>(make_fragment_layout(mainloop.tiled_copy_b, tCgB(_,_,_,0).shape()));

// Retile registers for copies
Tensor tArA = thr_copy_A.retile_D(tCrA);
Tensor tBrB = thr_copy_B.retile_D(tCrB);

// Retile global counting tensors for copies
Tensor tAgA = thr_copy_A.retile_S(tCgA);
Tensor tBgB = thr_copy_B.retile_S(tCgB);

auto tiled_prefetch_a = cute::prefetch_selector<Shape<Int<BLK_M>,Int<BLK_K>>, Num_SGs>(mainloop.tiled_copy_a);
auto tiled_prefetch_b = cute::prefetch_selector<Shape<Int<BLK_N>,Int<BLK_K>>, Num_SGs>(mainloop.tiled_copy_b);
auto thr_prefetch_A = tiled_prefetch_a.get_slice(thread_idx);
auto thr_prefetch_B = tiled_prefetch_b.get_slice(thread_idx);
auto thr_mma = tiled_mma.get_slice(thread_idx);

/* Register fragments for MMA */
auto tCrA = thr_mma.partition_sg_fragment_A(gA(_,_,0));
auto tCrB = thr_mma.partition_sg_fragment_B(gB(_,_,0));

/* Register fragments for copies */
auto tArA = thr_copy_a.partition_sg_fragment_D(gA(_,_,0));
auto tBrB = thr_copy_b.partition_sg_fragment_D(gB(_,_,0));

/* Partition global tensor (proxies) for copies */
Tensor tAgA = thr_copy_a.partition_S(gA);
Tensor tBgB = thr_copy_b.partition_S(gB);

// Partition global tile for prefetch
/* Create prefetch TiledCopy instances - different for legacy vs new copy operations */
auto [prefetch_a, prefetch_b, thr_prefetch_A, thr_prefetch_B] = [&]() {
if constexpr (!std::is_void_v<GmemTiledCopyA> && !std::is_void_v<GmemTiledCopyB>) {
// Legacy copy operations - use prefetch_selector
auto tiled_prefetch_a = cute::prefetch_selector<Shape<Int<BLK_M>,Int<BLK_K>>, Num_SGs>(copy_a);
auto tiled_prefetch_b = cute::prefetch_selector<Shape<Int<BLK_N>,Int<BLK_K>>, Num_SGs>(copy_b);
auto thr_prefetch_A = tiled_prefetch_a.get_slice(thread_idx);
auto thr_prefetch_B = tiled_prefetch_b.get_slice(thread_idx);
return std::make_tuple(tiled_prefetch_a, tiled_prefetch_b, thr_prefetch_A, thr_prefetch_B);
} else {
// New 2D copy operations - use make_block_2d_prefetch
auto prefetch_a = make_block_2d_prefetch(copy_a);
auto prefetch_b = make_block_2d_prefetch(copy_b);
auto thr_prefetch_A = prefetch_a.get_slice(thread_idx);
auto thr_prefetch_B = prefetch_b.get_slice(thread_idx);
return std::make_tuple(prefetch_a, prefetch_b, thr_prefetch_A, thr_prefetch_B);
}
}();

/* Partition global tensor (proxies) for prefetch */
auto pAgA = thr_prefetch_A.partition_S(gA);
auto pBgB = thr_prefetch_B.partition_S(gB);

#if CUTLASS_ENABLE_DEBUG_PRINTS
#define PRINT(x) print(#x ": "); print(x); print("\n");
if (cute::thread(LOG_THREAD, LOG_GROUP)) {
print("======================= A: \n");
PRINT(tCgA);
PRINT(tAgA);

PRINT(tCrA);
PRINT(tArA);
PRINT(mainloop.tiled_copy_a);
PRINT(copy_a);

print("======================= B: \n");
PRINT(tCgB);
PRINT(tBgB);

PRINT(tCrB);
PRINT(tBrB);
PRINT(mainloop.tiled_copy_b);
PRINT(copy_b);
}
#undef PRINT
#endif
@@ -243,19 +277,19 @@ struct CollectiveMma<MainloopIntelXeXMX16<Stages, Schedule>, TileShape_, Element

CUTLASS_PRAGMA_UNROLL
for (; prefetch_k < DispatchPolicy::Stages; prefetch_k++) {
prefetch(tiled_prefetch_a, pAgA(_, _, _, prefetch_k));
prefetch(tiled_prefetch_b, pBgB(_, _, _, prefetch_k));
prefetch(prefetch_a, pAgA(_, _, _, prefetch_k));
prefetch(prefetch_b, pBgB(_, _, _, prefetch_k));
}

for (int k_tile = k_start_idx; k_tile < k_tile_count + k_start_idx; k_tile++, prefetch_k++) {
barrier_arrive(barrier_scope);
// Copy gmem to rmem for the first k_tile
copy(mainloop.tiled_copy_a, tAgA(_,_,_,k_tile), tArA);
copy(mainloop.tiled_copy_b, tBgB(_,_,_,k_tile), tBrB);
copy(copy_a, tAgA(_,_,_,k_tile), tArA);
copy(copy_b, tBgB(_,_,_,k_tile), tBrB);

if (prefetch_k < k_tile_count) {
prefetch(tiled_prefetch_a, pAgA(_, _, _, prefetch_k));
prefetch(tiled_prefetch_b, pBgB(_, _, _, prefetch_k));
prefetch(prefetch_a, pAgA(_, _, _, prefetch_k));
prefetch(prefetch_b, pBgB(_, _, _, prefetch_k));
}

cute::gemm(tiled_mma, tCrA, tCrB, accum);
Expand All @@ -267,3 +301,4 @@ struct CollectiveMma<MainloopIntelXeXMX16<Stages, Schedule>, TileShape_, Element
} // namespace cutlass::gemm::collective

/////////////////////////////////////////////////////////////////////////////////////////////////
