diff --git a/include/cutlass/epilogue/collective/builders/xe_builder.inl b/include/cutlass/epilogue/collective/builders/xe_builder.inl index 809cede6f7..12c36b45c6 100644 --- a/include/cutlass/epilogue/collective/builders/xe_builder.inl +++ b/include/cutlass/epilogue/collective/builders/xe_builder.inl @@ -198,7 +198,7 @@ template < using CollectiveOp = cutlass::epilogue::collective::CollectiveEpilogue< DispatchPolicy, TileShape_MNK, - ElementAccumulator, + ElementC, StrideC, ElementD, StrideD, diff --git a/include/cutlass/epilogue/collective/xe_array_epilogue.hpp b/include/cutlass/epilogue/collective/xe_array_epilogue.hpp index 9b879bd14d..47bc0e582e 100644 --- a/include/cutlass/epilogue/collective/xe_array_epilogue.hpp +++ b/include/cutlass/epilogue/collective/xe_array_epilogue.hpp @@ -91,7 +91,6 @@ class CollectiveEpilogue< using CtaTileMNK = CtaTileMNK_; using FusionCallbacks = FusionCallbacks_; using ElementC = ElementC_; - using ElementAccumulator = ElementC_; using StrideC = StrideC_; using InternalStrideC = cute::remove_pointer_t; using ElementD = ElementD_; @@ -109,7 +108,8 @@ class CollectiveEpilogue< using GmemTiledCopyD = cute::conditional_t && not cute::is_void_v, CopyOpR2G, XE_2D_U32x8x16_ST_N>; using ElementOutput = ElementD; - using ElementCompute = ElementAccumulator; + using ElementCompute = typename ThreadEpilogueOp::ElementCompute; + using ElementAccumulator = ElementCompute; using ElementSource = typename FusionCallbacks::ElementSource; using ElementScalar = typename FusionCallbacks::ElementScalar; static constexpr FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest; @@ -371,7 +371,7 @@ class CollectiveEpilogue< auto thread_xe_store_d = params.xe_store_d.get_thread_slice(thread_idx); Tensor tCgD = thread_xe_store_d.partition_D(gD); - Tensor trC = make_tensor(Shape>{}); + Tensor trC = make_tensor(Shape>{}); Tensor trD_compute = make_tensor(Shape>{}); // Because Sm90 uses shared memory, they are not tied to using the same accumulator values diff --git a/include/cutlass/epilogue/collective/xe_epilogue.hpp b/include/cutlass/epilogue/collective/xe_epilogue.hpp index 05437ab180..b7bf5bf89e 100644 --- a/include/cutlass/epilogue/collective/xe_epilogue.hpp +++ b/include/cutlass/epilogue/collective/xe_epilogue.hpp @@ -90,7 +90,6 @@ class CollectiveEpilogue< using CtaTileMNK = CtaTileMNK_; using FusionCallbacks = FusionCallbacks_; using ElementC = ElementC_; - using ElementAccumulator = ElementC_; using StrideC = StrideC_; using ElementD = ElementD_; using StrideD = StrideD_; @@ -106,8 +105,8 @@ class CollectiveEpilogue< using GmemTiledCopyD = cute::conditional_t && not cute::is_void_v, CopyOpR2G, XE_2D_U32x8x16_ST_N>; using ElementOutput = ElementD; - using ElementCompute = ElementAccumulator; - + using ElementCompute = typename ThreadEpilogueOp::ElementCompute; + using ElementAccumulator = ElementCompute; static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]"); @@ -349,7 +348,7 @@ class CollectiveEpilogue< auto thread_xe_store_d = params.xe_store_d.get_thread_slice(thread_idx); Tensor tCgD = thread_xe_store_d.partition_D(gD); - Tensor trC = make_tensor(Shape>{}); + Tensor trC = make_tensor(Shape>{}); Tensor trD_compute = make_tensor(Shape>{}); // Because Sm90 uses shared memory, they are not tied to using the same accumulator values