Merged

Commits (23)
83f1de8  Test commit (Sep 29, 2025)
0b184f0  Enable new mma and copy atoms (Sep 29, 2025)
ef1bafa  Adding legacy code back for collectivemma and gemmuniversal (Sep 30, 2025)
f210ba3  Delete unwanted file (Sep 30, 2025)
5f5a8b7  Changes added based on feedback (Oct 1, 2025)
c55ac28  Remove xe_gemm_legacy as it's no longer used (Oct 1, 2025)
946b46c  Changes added based on feedback (Oct 3, 2025)
c97f011  Applied review comments (Oct 4, 2025)
9691e60  Add compile-time checks to enforce new XE copy atoms in block 2D func… (Oct 6, 2025)
93b076a  Modified static assert message (Oct 6, 2025)
a6f068c  Modified static assert message (Oct 6, 2025)
fcbfecf  Merge branch 'intel:main' into anamikac/add-newatoms (anamikac-intel, Oct 6, 2025)
e1e64f7  Move legacy example to legacy folder, pass 2D strides to make_block_2… (Oct 8, 2025)
ea67069  Applied reviewer comment (Oct 10, 2025)
e9878b9  This is an empty commit (Oct 10, 2025)
fbb7bb5  Preventing exceptions on older IGC versions (anamikac-intel, Oct 10, 2025)
4fb70c0  Remove unwanted returns from device-side params (Oct 12, 2025)
4fd4376  Modify compile-time checks to enforce new XE copy atoms in block 2D f… (Oct 13, 2025)
ca503bf  Applied review comments (Oct 17, 2025)
4eb3bf3  Add batch_idx to global tensor passed to make_block_2d_copy_* and Blo… (Oct 19, 2025)
07aa4c8  Merge branch 'intel:main' into anamikac/add-newatoms (anamikac-intel, Oct 20, 2025)
800480a  Added comments on why batch indexing used for make_block_2d_copy_* (Oct 20, 2025)
018ffb8  Merge branch 'intel:main' into anamikac/add-newatoms (anamikac-intel, Oct 22, 2025)
73 changes: 56 additions & 17 deletions examples/00_bmg_gemm/00_bmg_gemm.cpp
@@ -81,6 +81,58 @@

using namespace cute;

///////////////////////////////////////////////////////////////////////////////////////////////////

// Helper template to check whether a type is complete. sizeof(T) only
// compiles for complete types, so the partial specialization below is
// discarded by SFINAE for incomplete T and the primary template is used.
template <typename T, size_t = 0>
struct is_complete : std::false_type {};

template <typename T>
struct is_complete<T, 0 * sizeof(T)> : std::true_type {};

template <typename T>
static constexpr bool is_complete_v = is_complete<T>::value;
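
// Illustrative only (not part of this diff): a minimal demonstration of the
// is_complete idiom. `Fwd` and `Def` are hypothetical types introduced here
// purely for the example.
struct Fwd;          // forward-declared only: incomplete
struct Def {};       // defined: complete
static_assert(!is_complete_v<Fwd>, "Fwd is incomplete");
static_assert(is_complete_v<Def>, "Def is complete");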


// Helper function to choose the DPAS MMA op based on the A/B/C element types.
// If no DPAS op exists for the exact combination (i.e. the specialization is
// incomplete), fall back to an op with input upconversion.
template <typename TA, typename TB, typename TC>
auto
choose_mma_op()
{
  if constexpr (is_complete_v<XE_DPAS_TT<8, TC, TA, TB>>)
    return XE_DPAS_TT<8, TC, TA, TB>{};
  else if constexpr (is_same_v<std::decay_t<TA>, cute::bfloat16_t>)
    return XE_DPAS_TT<8, float, cute::bfloat16_t>{};
  else /* Use f16 by default as upconversion sequences are typically faster */
    return XE_DPAS_TT<8, float, cute::half_t>{};
}
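
// Illustrative only (not part of this diff): op selection happens entirely at
// compile time in an unevaluated context, e.g. for a bf16 GEMM accumulating
// in f32 (`ExampleOp` is a name introduced just for this sketch):
using ExampleOp = decltype(choose_mma_op<cute::bfloat16_t, cute::bfloat16_t, float>());
// ExampleOp is XE_DPAS_TT<8, float, bfloat16_t, bfloat16_t> when that
// specialization exists for the target; otherwise one of the upconversion
// fallbacks above is chosen.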

// Helper function to choose a tiled MMA based on tensor properties
template <typename TA, typename TB, typename TC, typename LayoutA, typename LayoutB>
auto
choose_tiled_mma()
{
  auto op = choose_mma_op<TA, TB, TC>();

  // True when either input is an 8-bit (or narrower) type
  constexpr bool byte = (cute::max(sizeof_bits_v<TA>, sizeof_bits_v<TB>) <= 8);

  constexpr bool is_A_transposed = std::is_same_v<LayoutA, cutlass::layout::ColumnMajor>;
  constexpr bool is_B_transposed = std::is_same_v<LayoutB, cutlass::layout::ColumnMajor>;

  // Use a single DPAS K-block per workgroup K-tile for the transposed cases;
  // otherwise cover two DPAS K-blocks per K-tile.
  constexpr bool use_1x_dpas_per_k = is_A_transposed || (byte && is_B_transposed);

  using _K = conditional_t<use_1x_dpas_per_k, C<op.K>, C<op.K * 2>>;

  using WGTile = Shape<_256, _256, _K>;                           // 256x256 WG tile size
  using SGLayout = Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>; // 8x4 SG tiling, n-major

  using MMA = typename TiledMMAHelper<MMA_Atom<decltype(op)>, Layout<WGTile>, SGLayout>::TiledMMA;

  return MMA{};
}
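
// Illustrative only (not part of this diff): typical use, mirroring the
// main() changes below (`Example*` names are introduced just for this sketch).
// With row-major bf16 A and B neither operand is transposed, so
// _K = op.K * 2 = _32 and the workgroup tile comes out as (_256, _256, _32),
// matching the hard-coded tile this PR replaces:
using ExampleTiledMma = decltype(choose_tiled_mma<cute::bfloat16_t, cute::bfloat16_t, float,
                                                  cutlass::layout::RowMajor, cutlass::layout::RowMajor>());
using ExampleTileShape = decltype(ExampleTiledMma{}.tile_mnk());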


///////////////////////////////////////////////////////////////////////////////////////////////////

// Command line options parsing
@@ -350,21 +402,8 @@ int main(int argc, const char** argv)
using GmemTiledCopyB = XE_2D_U16x32x32_LD_V;

// Workgroup-level tile
-using TileShape = Shape<_256, _256, _32>;
-
-// A TiledMMA struct defines a tiling of an MMA atom over M, N and K, combining both additional
-// hardware (sub-groups for Intel BMG) and iterations by each sub-group.
-//
-// The TiledMMAHelper struct defines a specific TiledMMA for a given MMA atom
-// (XE_8x16x16_F32BF16BF16F32_TT), TileShape (<256, 256, 32>) and sub-group layout (8x4x1). The
-// TiledMMA constructed using TiledMMAHelper has the property that each sub-group operates on a
-// single contiguous chunk of the work-group TileShape. For this configuration, this implies that
-// each sub-group operates on a contiguous 32x64x32 chunk (4x4x2 iterations). See
-// 0t_mma_atom.md#TiledMMAs for more info. Sub-groups are arranged row-major (stride 4,1,0) for
-// performance reasons.
-using TiledMma = // M=8,N=16,K=16, D=f32,A=bf16,B=bf16,C=f32
-    typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
-        Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
+using TiledMma = decltype(choose_tiled_mma<ElementInputA, ElementInputB, ElementOutput, LayoutA, LayoutB>());
+using TileShape = decltype(TiledMma{}.tile_mnk());

// For Intel BMG, PipelineStages defines how many k-blocks ahead to prefetch from A and B.
constexpr int PipelineStages = 2;
Expand Down Expand Up @@ -398,7 +437,7 @@ int main(int argc, const char** argv)
void, void>;

// GEMM Mainloop - iteration over blocks in K dimension
-using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma<
+using CollectiveMainloop = cutlass::gemm::collective::CollectiveMmaNew<
GEMMDispatchPolicy,
TileShape,
ElementInputA,
@@ -411,7 +450,7 @@ int main(int argc, const char** argv)
>;

// Define the whole kernel (mainloop and epilogue)
-using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+using GemmKernel = cutlass::gemm::kernel::GemmUniversalNew<
Shape<int, int, int, int>, // Defer global problem shape definition to runtime
CollectiveMainloop,
CollectiveEpilogue