Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions examples/00_bmg_gemm/00_bmg_gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,10 @@ int main(int argc, const char** argv)
// Refer https://github.com/intel/sycl-tla/blob/main/media/docs/cpp/xe_rearchitecture.md
using GmemTiledCopyA = void; //XE_LOAD_2D<16, 32, 32>;
using GmemTiledCopyB = void; //XE_LOAD_2D_VNNI<16, 32, 32>;
using GmemTiledCopyC = void; //XE_LOAD_2D<32, 8, 16>;
using GmemTiledCopyD = void; //XE_STORE_2D<32, 8, 16>;



// Workgroup-level tile
using TileShape = Shape<_256, _256, _32>;
Expand All @@ -369,9 +373,8 @@ int main(int argc, const char** argv)

// For Intel BMG, PipelineStages defines how many k-blocks ahead to prefetch from A and B.
constexpr int PipelineStages = 2;
// For older version of copy/mma atom, use cutlass::gemm::MainloopIntelXeXMX16 as dispatch policy
using GEMMDispatchPolicy = cutlass::gemm::MainloopXeL1Staged<PipelineStages>;
using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16;
using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeL1Staged;

// This is the 'default' epilogue operation (Linear Combination) which performs everything in:
// (D = alpha * (A*B) + beta * C)
Expand All @@ -394,9 +397,9 @@ int main(int argc, const char** argv)
ElementOutput,
cutlass::gemm::TagToStrideC_t<LayoutD>, // Converts CUTLASS 2.x to CUTLASS 3.x representation
FusionCallBacks,
XE_2D_U32x8x16_LD_N, // The copy atom used to load matrix C
GmemTiledCopyC, // The copy atom used to load matrix C
void, void,
XE_2D_U32x8x16_ST_N, // The copy atom used to store matrix D
GmemTiledCopyD, // The copy atom used to store matrix D
void, void>;

// GEMM Mainloop - iteration over blocks in K dimension
Expand Down
12 changes: 11 additions & 1 deletion include/cute/atom/copy_traits_xe_2d.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1143,12 +1143,22 @@ template <class CopyOp, class TiledMMA, class CTensor>
auto get_block_2d_copy_C(TiledMMA const& tiled_mma, CTensor const& c_tensor)
{
if constexpr (!std::is_void_v<CopyOp>) {
return make_block_2d_copy_C(CopyOp{}, tiled_mma, c_tensor);
return make_block_2d_copy_CD(CopyOp{}, tiled_mma, c_tensor);
} else {
return make_block_2d_copy_C(tiled_mma, c_tensor);
}
}

template <class CopyOp, class TiledMMA, class DTensor>
auto get_block_2d_copy_D(TiledMMA const& tiled_mma, DTensor const& d_tensor)
{
if constexpr (!std::is_void_v<CopyOp>) {
return make_block_2d_copy_CD(CopyOp{}, tiled_mma, d_tensor);
} else {
return make_block_2d_copy_D(tiled_mma, d_tensor);
}
}

//
// Display utilities
//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ class CollectiveEpilogue {
#include "sm100_epilogue_array_tma_warpspecialized.hpp"
#if defined (SYCL_INTEL_TARGET)
#include "xe_epilogue.hpp"
#include "xe_epilogue_legacy.hpp"
#include "xe_array_epilogue.hpp"
#endif
//
Expand Down
Loading
Loading