Skip to content
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/02_bmg_gemm_mixed_dtype/02_bmg_gemm_f16_u4_f16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ struct ExampleRunner {
using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16<PipelineStages>;
using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16;

- using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<ElementAccumulator, ElementCompute,
+ using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<ElementOutput, ElementCompute,
ElementAccumulator, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;

using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<EpilogueDispatchPolicy, EpilogueOp, TileShape,
Expand Down Expand Up @@ -648,7 +648,7 @@ int main(int argc, const char** argv) {
using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16MixedPrecision<PipelineStages>;
using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16;

- using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<ElementAccumulator, ElementComputeEpilogue,
+ using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<ElementOutput, ElementComputeEpilogue,
ElementAccumulator, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;

using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<EpilogueDispatchPolicy, EpilogueOp, TileShape,
Expand Down
4 changes: 2 additions & 2 deletions examples/02_bmg_gemm_mixed_dtype/02_bmg_gemm_f16_u4_s8.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ struct ExampleRunner {
using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16<PipelineStages>;
using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16;

- using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<ElementAccumulator, ElementCompute,
+ using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<ElementOutput, ElementCompute,
ElementAccumulator, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;

using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<EpilogueDispatchPolicy, EpilogueOp, TileShape,
Expand Down Expand Up @@ -735,7 +735,7 @@ int main(int argc, const char** argv) {
using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16MixedPrecision<PipelineStages>;
using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16;

- using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<ElementAccumulator, ElementComputeEpilogue,
+ using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<ElementOutput, ElementComputeEpilogue,
ElementAccumulator, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;

using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<EpilogueDispatchPolicy, EpilogueOp, TileShape,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ int main(int argc, const char** argv) {
using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16GroupMixedPrecision<PipelineStages>;
using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16Group;

- using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<ElementAccumulator, ElementComputeEpilogue,
+ using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<ElementOutput, ElementComputeEpilogue,
ElementAccumulator, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;

using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<EpilogueDispatchPolicy, EpilogueOp, TileShape,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ struct ExampleRunner {
using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16<PipelineStages>;
using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16;

- using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<ElementAccumulator, ElementCompute,
+ using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<ElementOutput, ElementCompute,
ElementAccumulator, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;

using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<EpilogueDispatchPolicy, EpilogueOp, TileShape,
Expand Down
6 changes: 3 additions & 3 deletions include/cutlass/epilogue/collective/xe_array_epilogue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ class CollectiveEpilogue<
static constexpr FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest;

static_assert(cute::is_same_v<typename FusionCallbacks::Operation,
- fusion::LinearCombination<ElementAccumulator, ElementCompute, ElementSource, ElementScalar, RoundStyle>>,
+ fusion::LinearCombination<ElementOutput, ElementCompute, ElementSource, ElementScalar, RoundStyle>>,
"Only Linear Combination Epilogue is supported for Grouped GEMM at the moment.");

static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize;
Expand Down Expand Up @@ -411,7 +411,7 @@ class CollectiveEpilogue<
cst_callbacks.begin();

auto acc_frag = recast<Array<ElementCompute, FragmentSize>>(accumulators);
- auto trD_compute_frag = recast<Array<ElementCompute, FragmentSize>>(trD_compute);
+ auto trD_compute_frag = recast<Array<ElementOutput, FragmentSize>>(trD_compute);

Tensor trD = make_tensor<ElementOutput>(Shape<Int<FragmentSize>>{});
auto trD_frag = recast<Array<ElementOutput, FragmentSize>>(trD);
Expand Down Expand Up @@ -445,7 +445,7 @@ class CollectiveEpilogue<
if constexpr (is_destination_supported) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(trD_compute_frag); ++i) {
- trD_frag(i) = cutlass::NumericArrayConverter<ElementOutput, ElementCompute, FragmentSize>{}(trD_compute_frag(i));
+ trD_frag(i) = cutlass::NumericArrayConverter<ElementOutput, ElementOutput, FragmentSize>{}(trD_compute_frag(i));
}
copy(params.xe_store_d.with(get<1>(load_store_tensors)), trD, tCgD(_, epi_m, epi_n));
}
Expand Down
4 changes: 2 additions & 2 deletions include/cutlass/epilogue/collective/xe_epilogue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ class CollectiveEpilogue<
cst_callbacks.begin();

auto acc_frag = recast<Array<ElementCompute, FragmentSize>>(accumulators);
- auto trD_compute_frag = recast<Array<ElementCompute, FragmentSize>>(trD_compute);
+ auto trD_compute_frag = recast<Array<ElementOutput, FragmentSize>>(trD_compute);

Tensor trD = make_tensor<ElementOutput>(Shape<Int<FragmentSize>>{});
auto trD_frag = recast<Array<ElementOutput, FragmentSize>>(trD);
Expand Down Expand Up @@ -423,7 +423,7 @@ class CollectiveEpilogue<
if constexpr (is_destination_supported) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(trD_compute_frag); ++i) {
- trD_frag(i) = cutlass::NumericArrayConverter<ElementOutput, ElementCompute, FragmentSize>{}(trD_compute_frag(i));
+ trD_frag(i) = cutlass::NumericArrayConverter<ElementOutput, ElementOutput, FragmentSize>{}(trD_compute_frag(i));
}
copy(params.xe_store_d, trD, tCgD(_, epi_m, epi_n));
}
Expand Down
4 changes: 2 additions & 2 deletions test/unit/gemm/device/default_gemm_configuration.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1465,7 +1465,7 @@ struct DefaultGemmConfigurationToCutlass3Types<
cutlass::gemm::collective::KernelScheduleAuto
>::CollectiveOp;

- using EpilogueOp = epilogue::fusion::LinearCombination<float, float>;
+ using EpilogueOp = epilogue::fusion::LinearCombination<ElementOutput, float>;

using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<
epilogue::IntelXeXMX16,
Expand Down Expand Up @@ -1567,7 +1567,7 @@ struct DefaultGemmConfigurationToCutlass3Types<
cutlass::gemm::collective::KernelScheduleAuto
>::CollectiveOp;

- using EpilogueOp = epilogue::fusion::LinearCombination<float, float>;
+ using EpilogueOp = epilogue::fusion::LinearCombination<ElementOutput, float>;

using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<
epilogue::IntelXeXMX16,
Expand Down
2 changes: 1 addition & 1 deletion test/unit/gemm/device/default_gemm_group_configuration.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ struct DefaultGemmGroupConfiguration<

using TiledMma = typename CollectiveMainloop::TiledMma;

- using EpilogueOp = epilogue::fusion::LinearCombination<float, float>;
+ using EpilogueOp = epilogue::fusion::LinearCombination<ElementOutput, float>;

using FusionCallBacks = epilogue::fusion::FusionCallbacks<
epilogue::IntelXeXMX16Group,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ struct XE_Device_Gemm_fp16_fp16_f16_tensor_op_f32 {
cute::half_t, LayoutA,
cute::half_t, LayoutB,
float, layout::RowMajor,
- cute::half_t>;
+ float>;

using Gemm = gemm::device::GemmUniversalAdapter<gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
Expand Down