Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ template <
using CollectiveOp = cutlass::epilogue::collective::CollectiveEpilogue<
DispatchPolicy,
TileShape_MNK,
ElementAccumulator,
ElementC,
StrideC,
ElementD,
StrideD,
Expand Down
6 changes: 3 additions & 3 deletions include/cutlass/epilogue/collective/xe_array_epilogue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ class CollectiveEpilogue<
using CtaTileMNK = CtaTileMNK_;
using FusionCallbacks = FusionCallbacks_;
using ElementC = ElementC_;
using ElementAccumulator = ElementC_;
using StrideC = StrideC_;
using InternalStrideC = cute::remove_pointer_t<StrideC>;
using ElementD = ElementD_;
Expand All @@ -109,7 +108,8 @@ class CollectiveEpilogue<
using GmemTiledCopyD = cute::conditional_t<not cute::is_void_v<ElementD> && not cute::is_void_v<CopyOpR2G>,
CopyOpR2G, XE_2D_U32x8x16_ST_N>;
using ElementOutput = ElementD;
using ElementCompute = ElementAccumulator;
using ElementCompute = typename ThreadEpilogueOp::ElementCompute;
using ElementAccumulator = ElementCompute;
using ElementSource = typename FusionCallbacks::ElementSource;
using ElementScalar = typename FusionCallbacks::ElementScalar;
static constexpr FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest;
Expand Down Expand Up @@ -371,7 +371,7 @@ class CollectiveEpilogue<
auto thread_xe_store_d = params.xe_store_d.get_thread_slice(thread_idx);
Tensor tCgD = thread_xe_store_d.partition_D(gD);

Tensor trC = make_tensor<typename TiledMma::ValTypeC>(Shape<Int<FragmentSize>>{});
Tensor trC = make_tensor<ElementC>(Shape<Int<FragmentSize>>{});
Tensor trD_compute = make_tensor<ElementCompute>(Shape<Int<FragmentSize>>{});

// Because Sm90 uses shared memory, they are not tied to using the same accumulator values
Expand Down
7 changes: 3 additions & 4 deletions include/cutlass/epilogue/collective/xe_epilogue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ class CollectiveEpilogue<
using CtaTileMNK = CtaTileMNK_;
using FusionCallbacks = FusionCallbacks_;
using ElementC = ElementC_;
using ElementAccumulator = ElementC_;
using StrideC = StrideC_;
using ElementD = ElementD_;
using StrideD = StrideD_;
Expand All @@ -106,8 +105,8 @@ class CollectiveEpilogue<
using GmemTiledCopyD = cute::conditional_t<not cute::is_void_v<ElementD> && not cute::is_void_v<CopyOpR2G>,
CopyOpR2G, XE_2D_U32x8x16_ST_N>;
using ElementOutput = ElementD;
using ElementCompute = ElementAccumulator;

using ElementCompute = typename ThreadEpilogueOp::ElementCompute;
using ElementAccumulator = ElementCompute;
static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize;

static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]");
Expand Down Expand Up @@ -349,7 +348,7 @@ class CollectiveEpilogue<
auto thread_xe_store_d = params.xe_store_d.get_thread_slice(thread_idx);
Tensor tCgD = thread_xe_store_d.partition_D(gD);

Tensor trC = make_tensor<typename TiledMma::ValTypeC>(Shape<Int<FragmentSize>>{});
Tensor trC = make_tensor<ElementC>(Shape<Int<FragmentSize>>{});
Tensor trD_compute = make_tensor<ElementCompute>(Shape<Int<FragmentSize>>{});

// Because Sm90 uses shared memory, they are not tied to using the same accumulator values
Expand Down
Loading