Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/00_bmg_gemm/00_bmg_gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -350,8 +350,8 @@ int main(int argc, const char** argv)
// Refer https://github.com/intel/sycl-tla/blob/main/media/docs/cpp/xe_rearchitecture.md
using GmemTiledCopyA = void; //XE_LOAD_2D<16, 32, 32>;
using GmemTiledCopyB = void; //XE_LOAD_2D_VNNI<16, 32, 32>;
using GmemTiledCopyC = void; //XE_LOAD_2D<32, 8, 16>;
using GmemTiledCopyD = void; //XE_STORE_2D<32, 8, 16>;
using GmemTiledCopyC = XE_LOAD_2D<32, 8, 16>;
using GmemTiledCopyD = XE_STORE_2D<32, 8, 16>;



Expand Down
170 changes: 101 additions & 69 deletions include/cutlass/epilogue/collective/xe_epilogue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,24 @@ class CollectiveEpilogue<
return fusion_callbacks.is_producer_load_needed();
}

template<typename Tensor>
CUTLASS_DEVICE auto reshape_into_smaller_fragments(Tensor&& tensor) {
using namespace cute;

auto target_stride = make_stride(
make_stride(cute::ScaledBasis<cute::Int<1>, 0>{}, _0{}),
cute::ScaledBasis<cute::Int<8>, 0>{},
cute::ScaledBasis<cute::Int<16>, 1>{}
);

auto target_layout = make_layout(
make_shape(make_shape(_8{}, _1{}), _4{}, _4{}),
target_stride
);

return make_tensor(tensor.data(), target_layout);
}

template<
class ProblemShapeMNKL,
class TileShapeMNK,
Expand All @@ -283,25 +301,52 @@ class CollectiveEpilogue<
static constexpr auto BLK_M = get<0>(CtaTileMNK{});
static constexpr auto BLK_N = get<1>(CtaTileMNK{});
static constexpr auto BLK_K = get<2>(CtaTileMNK{});
static constexpr auto ATOM_M = get<1>(typename TiledMma::ThrLayoutVMNK{}.shape());
static constexpr auto ATOM_N = get<2>(typename TiledMma::ThrLayoutVMNK{}.shape());
static constexpr auto ATOM_K = get<3>(typename TiledMma::ThrLayoutVMNK{}.shape());

static_assert(
BLK_M % ATOM_M == 0 &&
BLK_N % ATOM_N == 0 &&
BLK_K % ATOM_K == 0,
"expected CTATileMNK to be evenly divided by TiledMma::ThrLayoutVMNK");
static constexpr auto SG_M = BLK_M / ATOM_M;
static constexpr auto SG_N = BLK_N / ATOM_N;
static constexpr auto SG_K = BLK_K / ATOM_K;
using SubgroupTileShape = Shape<decltype(SG_M), decltype(SG_N), decltype(SG_K)>;

static constexpr int FragsM = get<0>(SubgroupTileShape{}) / get<0>(MmaAtomShape()); // A frags per sub_group
static constexpr int FragsN = get<1>(SubgroupTileShape{}) / get<1>(MmaAtomShape()); // B frags per sub_group

static constexpr int FragmentSize = (get<0>(MmaAtomShape()) * get<1>(MmaAtomShape())) / SubgroupSize;

// Indexing variables
auto [M, N, K, L] = problem_shape_mnkl;
auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl;
auto m_sg = get_sub_group_id() / ATOM_N;
auto n_sg = get_sub_group_id() % ATOM_N;

auto sg_local_m_coord = get_sub_group_id() / ATOM_N;
auto sg_local_n_coord = get_sub_group_id() % ATOM_N;

auto sg_m_coord = m_coord * ATOM_M + sg_local_m_coord;
auto sg_n_coord = n_coord * ATOM_N + sg_local_n_coord;
auto sg_coord = make_coord(sg_m_coord, sg_n_coord, k_coord, l_coord);

// Workgroup coordinates (no subgroup indexing needed)
auto wg_coord = make_coord(m_coord, n_coord, k_coord, l_coord);

bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed();

auto batch_idx = get<3>(wg_coord);
auto copy_c = get_block_2d_copy_C<CopyOpG2R>(tiled_mma, params.mC(_,_,batch_idx));
auto copy_d = get_block_2d_copy_D<CopyOpR2G>(tiled_mma, params.mD(_,_,batch_idx));



// Represent the full output tensor
Tensor mD_mnl = cute::get_xe_tensor(make_shape(M,N,L));

// Tile the output tensor for the current workgroup
Tensor gD = local_tile(mD_mnl, take<0,2>(CtaTileMNK{}), remove<2>(wg_coord)); // (BLK_M,BLK_N) // change made
Tensor gD = local_tile(mD_mnl, take<0,2>(CtaTileMNK{}), remove<2>(wg_coord)); // (BLK_M,BLK_N)

// Get thread-level partitioning across the entire workgroup tile
auto thread_xe_load_c = copy_c.get_thread_slice(thread_idx);
Expand All @@ -310,105 +355,92 @@ class CollectiveEpilogue<
auto thread_xe_store_d = copy_d.get_thread_slice(thread_idx);
Tensor tCgD = thread_xe_store_d.partition_D(gD);

// Create tensors sized for workgroup-level operation
Tensor trC = make_tensor<typename TiledMma::ValTypeC>(tCgC.shape());
Tensor trD_compute = make_tensor<ElementCompute>(tCgD.shape());
Tensor trD = make_tensor<ElementOutput>(tCgD.shape());
auto tCgC_frag = reshape_into_smaller_fragments(tCgC);
auto tCgD_frag = reshape_into_smaller_fragments(tCgD);

Tensor trC = make_tensor<typename TiledMma::ValTypeC>(Shape<Int<FragmentSize>>{});
Tensor trD_compute = make_tensor<ElementCompute>(Shape<Int<FragmentSize>>{});
ThrCopy thread_g2r = copy_c.get_slice(thread_idx);

auto mn_shape = shape(typename decltype(copy_d)::Tiler_MN{});

// OOB predication for tile quantization "residue"
Tensor mD_crd = make_identity_tensor(make_shape(M,N));
Tensor cD = local_tile(mD_crd, take<0,2>(CtaTileMNK{}), make_coord(m_coord, n_coord));
Tensor tRS_cD = thread_g2r.partition_S(flat_divide(cD, mn_shape));
// Absolute coordinate tensors (dynamic)
Tensor mD_crd = make_identity_tensor(make_shape(M,N)); // (M,N)
Tensor cD = local_tile(mD_crd, take<0,2>(SubgroupTileShape{}), make_coord(sg_m_coord, sg_n_coord));
Tensor cD_mn = local_tile(mD_crd, take<0,2>(CtaTileMNK{}), make_coord(m_coord, n_coord)); // (CTA_M,CTA_N)
Tensor tRS_cD_mn = thread_g2r.partition_S(flat_divide(cD_mn, mn_shape)); // (G2R,G2R_M,G2R_N,EPI_M,EPI_N)

Tensor tRS_cD_coord = make_coord_tensor(tRS_cD.layout());
Tensor tRS_cD = make_coord_tensor(tRS_cD_mn.layout());

// Get fusion callbacks at workgroup level
// Get the fusion callbacks
// Arguments passed here relate to sub-group tiles, rather than CTA (work-group) tiles
constexpr bool RefSrc = true;
auto residue_mn = make_coord(M, N);
auto residue_mn = make_coord(M, N); //TODO(Codeplay): this is not correct
auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{
problem_shape_mnkl,
CtaTileMNK{}, // Use workgroup tile shape
wg_coord, // Use workgroup coordinates
SubgroupTileShape{},
sg_coord,
tiled_mma,
mn_shape,
copy_d,
cD,
residue_mn,
tRS_cD_coord,
tRS_cD,
residue_mn,
trC,
thread_idx,
};
auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks<RefSrc>(cst_args);
auto synchronize = [&] () {};

cst_callbacks.begin();
// Load C tile if needed (distributed across all threads in workgroup)
if (is_C_load_needed) {
copy(copy_c, tCgC, trC);
}

// Single previsit for entire workgroup tile
cst_callbacks.previsit(0, 0, 0, is_C_load_needed);

static constexpr int FragmentSize = get<0>(MmaAtomShape()) * get<1>(MmaAtomShape());
constexpr int num_fragments = size(accumulators) / FragmentSize;
auto acc_frag = recast<Array<ElementCompute, FragmentSize>>(accumulators);
auto trD_compute_frag = recast<Array<ElementCompute, FragmentSize>>(trD_compute);

Tensor trD = make_tensor<ElementOutput>(Shape<Int<FragmentSize>>{});
auto trD_frag = recast<Array<ElementOutput, FragmentSize>>(trD);

constexpr int ValuesLoaded =
FragsM * FragsN * FragmentSize * SubgroupSize * ATOM_M * ATOM_N * ATOM_K;
constexpr int MN = get<0>(CtaTileMNK{}) * get<1>(CtaTileMNK{});
static_assert(ValuesLoaded == MN, "the total elements loaded by all threads should be the same as MxN" );


auto synchronize = [&] () {};
CUTLASS_PRAGMA_UNROLL
for (int epi_v = 0; epi_v < num_fragments; ++epi_v) {
// Extract fragment
Array<ElementAccumulator, FragmentSize> frg_acc;
for (int epi_n = 0; epi_n < FragsN; epi_n++) {
CUTLASS_PRAGMA_UNROLL
for (int f = 0; f < FragmentSize; ++f) {
frg_acc[f] = accumulators(epi_v * FragmentSize + f);
}

// Process fragment
auto result_frg = cst_callbacks.visit(frg_acc, epi_v, 0, 0);

// Store results
CUTLASS_PRAGMA_UNROLL
for (int f = 0; f < FragmentSize; ++f) {
trD_compute(epi_v * FragmentSize + f) = result_frg[f];
}
}
for (int epi_m = 0; epi_m < FragsM; epi_m++) {
cst_callbacks.begin_loop(epi_m, epi_n);

cst_callbacks.reduce(nullptr, synchronize, 0, 0, true, trD_compute);
if (is_C_load_needed) {
copy(copy_c, tCgC_frag(_, epi_m, epi_n), trC);
}

cst_callbacks.previsit(epi_m, epi_n, 0, is_C_load_needed);

auto acc_frag_mn = acc_frag(_, epi_m, epi_n);

if constexpr (is_destination_supported) {
// Convert fragments using NumericArrayConverter
constexpr int num_fragments_trD_compute = size(trD_compute) / FragmentSize;
using Converter = cutlass::NumericArrayConverter<ElementOutput, ElementCompute, FragmentSize>;
Converter converter{};

CUTLASS_PRAGMA_UNROLL
for (int epi_v = 0; epi_v < num_fragments_trD_compute; ++epi_v) {
// Extract compute fragment
Array<ElementCompute, FragmentSize> trD_compute_frag;
CUTLASS_PRAGMA_UNROLL
for (int f = 0; f < FragmentSize; ++f) {
trD_compute_frag[f] = trD_compute(epi_v * FragmentSize + f);
for (int epi_v = 0; epi_v < size<0>(trD_compute_frag); ++epi_v) {
trD_compute_frag(epi_v) = cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n);
}

// Convert fragment
auto trD_frag = converter(trD_compute_frag);

// Store converted fragment
CUTLASS_PRAGMA_UNROLL
for (int f = 0; f < FragmentSize; ++f) {
trD(epi_v * FragmentSize + f) = trD_frag[f];

cst_callbacks.reduce(nullptr, synchronize, epi_m, epi_n, (epi_m == FragsM - 1 && epi_n == FragsN - 1), trD_compute_frag);

if constexpr (is_destination_supported) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(trD_compute_frag); ++i) {
trD_frag(i) = cutlass::NumericArrayConverter<ElementOutput, ElementCompute, FragmentSize>{}(trD_compute_frag(i));
}
copy(copy_d, trD, tCgD_frag(_, epi_m, epi_n));
}

cst_callbacks.end_loop(epi_m, epi_n);

}

copy(copy_d, trD, tCgD);
}
}

cst_callbacks.end();
cst_callbacks.end();

}

Expand Down
Loading