Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 16 additions & 9 deletions benchmarks/gemm/benchmarks_sycl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -319,12 +319,13 @@ using PvcMixedPrecisionGemmFP16U4FP16F16FP16S4_RCR_1 = cutlass::gemm::device::Mi
cutlass::half_t, cutlass::layout::RowMajor,
cutlass::uint4_t, cutlass::layout::ColumnMajor,
cutlass::half_t, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
cutlass::half_t, cute::Stride<_1, int64_t, int64_t>,
cutlass::int4_t, cute::Stride<_8, cute::Stride<_1, int64_t>, int64_t>,
Shape<_32, _128, _32>, Scheduler::Gemm,
typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32F16F16F32_TT>, Layout<Shape<_32, _128, _32>>,
Layout<Shape<_1, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA,
XE_2D_U16x32x32_LD_N, XE_2D_U4x32x16_LD_T, XE_2D_U16x8x16_ST_N,
XE_2D_U16x32x32_LD_N, XE_2D_U4x32x16_LD_T, XE_2D_U32x8x16_ST_N,
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. I think it's better to add copy_traits for both C and D for different data type. XE_2D_U16x8x16_ST_N here
    means storing to D as F16, you shouldn't change it. If you want to store D as F32, you should add a new
    GemmConfiguration for it.
  2. can you refine the data type name in "FP16U4FP16F16FP16S4" according to your changes.
    The data type " FP16U4FP16F16FP16S4 " in the name are: A, B, C, Mma, Scale, Zero

cutlass::epilogue::fusion::LinearCombination<float, float,
float, float, cutlass::FloatRoundStyle::round_to_nearest>,
2
Expand All @@ -335,12 +336,13 @@ using PvcMixedPrecisionGemmBF16U4BF16BF16BF16S4_RCR_1 = cutlass::gemm::device::M
cutlass::bfloat16_t, cutlass::layout::RowMajor,
cutlass::uint4_t, cutlass::layout::ColumnMajor,
cutlass::bfloat16_t, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
cutlass::bfloat16_t, cute::Stride<_1, int64_t, int64_t>,
cutlass::int4_t, cute::Stride<_8, cute::Stride<_1, int64_t>, int64_t>,
Shape<_32, _128, _32>, Scheduler::Gemm,
typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<Shape<_32, _128, _32>>,
Layout<Shape<_1, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA,
XE_2D_U16x32x32_LD_N, XE_2D_U4x32x16_LD_T, XE_2D_U16x8x16_ST_N,
XE_2D_U16x32x32_LD_N, XE_2D_U4x32x16_LD_T, XE_2D_U32x8x16_ST_N,
cutlass::epilogue::fusion::LinearCombination<float, float,
float, float, cutlass::FloatRoundStyle::round_to_nearest>,
2
Expand All @@ -351,28 +353,29 @@ using PvcMixedPrecisionGemmFP16U4FP16S8FP16S4_RCR_1 = cutlass::gemm::device::Mix
cutlass::half_t, cutlass::layout::RowMajor,
cutlass::uint4_t, cutlass::layout::ColumnMajor,
cutlass::half_t, cutlass::layout::RowMajor,
int, cutlass::layout::RowMajor,
cutlass::half_t, cute::Stride<_1, int64_t, int64_t>,
cutlass::int4_t, cute::Stride<_8, cute::Stride<_1, int64_t>, int64_t>,
Shape<_32, _128, _32>, Scheduler::Gemm,
typename TiledMMAHelper<MMA_Atom<XE_8x16x32_S32S8S8S32_TT>, Layout<Shape<_32, _128, _32>>,
Layout<Shape<_1, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA,
XE_2D_Packed_U16x32x32_LD_N, XE_2D_U4x32x16_LD_T, XE_2D_U16x8x16_ST_N,
XE_2D_Packed_U16x32x32_LD_N, XE_2D_U4x32x16_LD_T, XE_2D_U32x8x16_ST_N,
cutlass::epilogue::fusion::LinearCombination<int, int,
int, int, cutlass::FloatRoundStyle::round_to_nearest>,
2
>;

using PvcMixedPrecisionGemmFP16U4S8S8FP16S4_RCR_1 = cutlass::gemm::device::MixedPrecisionGemmConfiguration<
cutlass::arch::IntelXe,
cutlass::half_t, cutlass::layout::RowMajor,
cutlass::uint4_t, cutlass::layout::ColumnMajor,
cutlass::int8_t, cutlass::layout::RowMajor,
int, cutlass::layout::RowMajor,
cutlass::half_t, cute::Stride<_1, int64_t, int64_t>,
cutlass::int4_t, cute::Stride<_8, cute::Stride<_1, int64_t>, int64_t>,
Shape<_32, _128, _32>, Scheduler::Gemm,
typename TiledMMAHelper<MMA_Atom<XE_8x16x32_S32S8S8S32_TT>, Layout<Shape<_32, _128, _32>>,
Layout<Shape<_1, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA,
XE_2D_Packed_U16x32x32_LD_N, XE_2D_U4x32x16_LD_T, XE_2D_U8x8x16_ST_N,
XE_2D_Packed_U16x32x32_LD_N, XE_2D_U4x32x16_LD_T, XE_2D_U32x8x16_ST_N,
cutlass::epilogue::fusion::LinearCombination<int, int,
int, int, cutlass::FloatRoundStyle::round_to_nearest>,
2
Expand All @@ -383,12 +386,13 @@ using PvcMixedPrecisionGemmBF16U4BF16S8BF16S4_RCR_1 = cutlass::gemm::device::Mix
cutlass::bfloat16_t, cutlass::layout::RowMajor,
cutlass::uint4_t, cutlass::layout::ColumnMajor,
cutlass::bfloat16_t, cutlass::layout::RowMajor,
int, cutlass::layout::RowMajor,
cutlass::bfloat16_t, cute::Stride<_1, int64_t, int64_t>,
cutlass::int4_t, cute::Stride<_8, cute::Stride<_1, int64_t>, int64_t>,
Shape<_32, _128, _32>, Scheduler::Gemm,
typename TiledMMAHelper<MMA_Atom<XE_8x16x32_S32S8S8S32_TT>, Layout<Shape<_32, _128, _32>>,
Layout<Shape<_1, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA,
XE_2D_Packed_U16x32x32_LD_N, XE_2D_U4x32x16_LD_T, XE_2D_U16x8x16_ST_N,
XE_2D_Packed_U16x32x32_LD_N, XE_2D_U4x32x16_LD_T, XE_2D_U32x8x16_ST_N,
cutlass::epilogue::fusion::LinearCombination<int, int,
int, int, cutlass::FloatRoundStyle::round_to_nearest>,
2
Expand All @@ -399,12 +403,13 @@ using PvcMixedPrecisionGemmBF16U4S8S8BF16S4_RCR_1 = cutlass::gemm::device::Mixed
cutlass::bfloat16_t, cutlass::layout::RowMajor,
cutlass::uint4_t, cutlass::layout::ColumnMajor,
cutlass::int8_t, cutlass::layout::RowMajor,
int, cutlass::layout::RowMajor,
cutlass::bfloat16_t, cute::Stride<_1, int64_t, int64_t>,
cutlass::int4_t, cute::Stride<_8, cute::Stride<_1, int64_t>, int64_t>,
Shape<_32, _128, _32>, Scheduler::Gemm,
typename TiledMMAHelper<MMA_Atom<XE_8x16x32_S32S8S8S32_TT>, Layout<Shape<_32, _128, _32>>,
Layout<Shape<_1, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA,
XE_2D_Packed_U16x32x32_LD_N, XE_2D_U4x32x16_LD_T, XE_2D_U8x8x16_ST_N,
XE_2D_Packed_U16x32x32_LD_N, XE_2D_U4x32x16_LD_T, XE_2D_U32x8x16_ST_N,
cutlass::epilogue::fusion::LinearCombination<int, int,
int, int, cutlass::FloatRoundStyle::round_to_nearest>,
2
Expand All @@ -415,12 +420,13 @@ using PvcMixedPrecisionGemmBF16S8BF16S8BF16S8_RCR_1 = cutlass::gemm::device::Mix
cutlass::bfloat16_t, cutlass::layout::RowMajor,
cutlass::int8_t, cutlass::layout::ColumnMajor,
cutlass::bfloat16_t, cutlass::layout::RowMajor,
int, cutlass::layout::RowMajor,
cutlass::bfloat16_t, cute::Stride<_1, int64_t, int64_t>,
cutlass::int8_t, cute::Stride<_1, int64_t, int64_t>,
Shape<_32, _128, _32>, Scheduler::Gemm,
typename TiledMMAHelper<MMA_Atom<XE_8x16x32_S32S8S8S32_TT>, Layout<Shape<_32, _128, _32>>,
Layout<Shape<_1, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA,
XE_2D_Packed_U16x32x32_LD_N, XE_2D_U8x16x32_LD_T, XE_2D_U16x8x16_ST_N,
XE_2D_Packed_U16x32x32_LD_N, XE_2D_U8x16x32_LD_T, XE_2D_U32x8x16_ST_N,
cutlass::epilogue::fusion::LinearCombination<int, int,
int, int, cutlass::FloatRoundStyle::round_to_nearest>,
2
Expand All @@ -431,12 +437,13 @@ using PvcMixedPrecisionGemmFP16S8FP16S8FP16S8_RCR_1 = cutlass::gemm::device::Mix
cutlass::half_t, cutlass::layout::RowMajor,
cutlass::int8_t, cutlass::layout::ColumnMajor,
cutlass::half_t, cutlass::layout::RowMajor,
int, cutlass::layout::RowMajor,
cutlass::half_t, cute::Stride<_1, int64_t, int64_t>,
cutlass::int8_t, cute::Stride<_1, int64_t, int64_t>,
Shape<_32, _128, _32>, Scheduler::Gemm,
typename TiledMMAHelper<MMA_Atom<XE_8x16x32_S32S8S8S32_TT>, Layout<Shape<_32, _128, _32>>,
Layout<Shape<_1, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA,
XE_2D_Packed_U16x32x32_LD_N, XE_2D_U8x16x32_LD_T, XE_2D_U16x8x16_ST_N,
XE_2D_Packed_U16x32x32_LD_N, XE_2D_U8x16x32_LD_T, XE_2D_U32x8x16_ST_N,
cutlass::epilogue::fusion::LinearCombination<int, int,
int, int, cutlass::FloatRoundStyle::round_to_nearest>,
2
Expand Down
6 changes: 4 additions & 2 deletions benchmarks/gemm/gemm_configuration_sycl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ template<
class ArchTag,
class ElementA, class LayoutA,
class ElementB, class LayoutB, class ElementC, typename LayoutC,
class ElementD, typename LayoutD,
class ElementScale, typename StrideS,
class ElementZero, typename StrideZ,
class TileShape, Scheduler TileScheduler,
Expand Down Expand Up @@ -175,6 +176,7 @@ struct GemmConfiguration<
template<class ElementA, class LayoutA,
class ElementB, class LayoutB,
class ElementC, typename LayoutC,
class ElementD, typename LayoutD,
class ElementScale, typename StrideS,
class ElementZero, typename StrideZ,
class TileShape, Scheduler TileScheduler,
Expand All @@ -185,13 +187,13 @@ struct MixedPrecisionGemmConfiguration<
ElementA, LayoutA,
ElementB, LayoutB,
ElementC, LayoutC,
ElementD, LayoutD,
ElementScale, StrideS,
ElementZero, StrideZ,
TileShape, TileScheduler, TiledMma,
GmemTiledCopyA, GmemTiledCopyB,
GmemTiledCopyC, EpilogueOp, Stages>
{
using LayoutD = LayoutC;

using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16MixedPrecision<Stages>;
using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16;
Expand All @@ -205,7 +207,7 @@ struct MixedPrecisionGemmConfiguration<
TileShape,
ElementAccumulator,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be ElementC, right?

cutlass::gemm::TagToStrideC_t<LayoutC>,
ElementC,
ElementD,
cutlass::gemm::TagToStrideC_t<LayoutD>,
FusionCallBacks,
XE_2D_U32x8x16_LD_N,
Expand Down
4 changes: 2 additions & 2 deletions examples/02_bmg_gemm_mixed_dtype/02_bmg_gemm_f16_u4_f16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ struct ExampleRunner {
using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16<PipelineStages>;
using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16;

using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<ElementAccumulator, ElementCompute,
using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<ElementOutput, ElementCompute,
ElementAccumulator, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;

using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<EpilogueDispatchPolicy, EpilogueOp, TileShape,
Expand Down Expand Up @@ -648,7 +648,7 @@ int main(int argc, const char** argv) {
using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16MixedPrecision<PipelineStages>;
using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16;

using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<ElementAccumulator, ElementComputeEpilogue,
using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<ElementOutput, ElementComputeEpilogue,
ElementAccumulator, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;

using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<EpilogueDispatchPolicy, EpilogueOp, TileShape,
Expand Down
4 changes: 2 additions & 2 deletions examples/02_bmg_gemm_mixed_dtype/02_bmg_gemm_f16_u4_s8.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ struct ExampleRunner {
using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16<PipelineStages>;
using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16;

using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<ElementAccumulator, ElementCompute,
using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<ElementOutput, ElementCompute,
ElementAccumulator, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;

using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<EpilogueDispatchPolicy, EpilogueOp, TileShape,
Expand Down Expand Up @@ -735,7 +735,7 @@ int main(int argc, const char** argv) {
using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16MixedPrecision<PipelineStages>;
using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16;

using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<ElementAccumulator, ElementComputeEpilogue,
using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<ElementOutput, ElementComputeEpilogue,
ElementAccumulator, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;

using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<EpilogueDispatchPolicy, EpilogueOp, TileShape,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ int main(int argc, const char** argv) {
using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16GroupMixedPrecision<PipelineStages>;
using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16Group;

using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<ElementAccumulator, ElementComputeEpilogue,
using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<ElementOutput, ElementComputeEpilogue,
ElementAccumulator, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;

using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<EpilogueDispatchPolicy, EpilogueOp, TileShape,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ struct ExampleRunner {
using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16<PipelineStages>;
using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16;

using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<ElementAccumulator, ElementCompute,
using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<ElementOutput, ElementCompute,
ElementAccumulator, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;

using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<EpilogueDispatchPolicy, EpilogueOp, TileShape,
Expand Down
6 changes: 3 additions & 3 deletions include/cutlass/epilogue/collective/xe_array_epilogue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ class CollectiveEpilogue<
static constexpr FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest;

static_assert(cute::is_same_v<typename FusionCallbacks::Operation,
fusion::LinearCombination<ElementAccumulator, ElementCompute, ElementSource, ElementScalar, RoundStyle>>,
fusion::LinearCombination<ElementOutput, ElementCompute, ElementSource, ElementScalar, RoundStyle>>,
"Only Linear Combination Epilogue is supported for Grouped GEMM at the moment.");

static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize;
Expand Down Expand Up @@ -411,7 +411,7 @@ class CollectiveEpilogue<
cst_callbacks.begin();

auto acc_frag = recast<Array<ElementCompute, FragmentSize>>(accumulators);
auto trD_compute_frag = recast<Array<ElementCompute, FragmentSize>>(trD_compute);
auto trD_compute_frag = recast<Array<ElementOutput, FragmentSize>>(trD_compute);

Tensor trD = make_tensor<ElementOutput>(Shape<Int<FragmentSize>>{});
auto trD_frag = recast<Array<ElementOutput, FragmentSize>>(trD);
Expand Down Expand Up @@ -445,7 +445,7 @@ class CollectiveEpilogue<
if constexpr (is_destination_supported) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(trD_compute_frag); ++i) {
trD_frag(i) = cutlass::NumericArrayConverter<ElementOutput, ElementCompute, FragmentSize>{}(trD_compute_frag(i));
trD_frag(i) = cutlass::NumericArrayConverter<ElementOutput, ElementOutput, FragmentSize>{}(trD_compute_frag(i));
}
copy(params.xe_store_d.with(get<1>(load_store_tensors)), trD, tCgD(_, epi_m, epi_n));
}
Expand Down
4 changes: 2 additions & 2 deletions include/cutlass/epilogue/collective/xe_epilogue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ class CollectiveEpilogue<
cst_callbacks.begin();

auto acc_frag = recast<Array<ElementCompute, FragmentSize>>(accumulators);
auto trD_compute_frag = recast<Array<ElementCompute, FragmentSize>>(trD_compute);
auto trD_compute_frag = recast<Array<ElementOutput, FragmentSize>>(trD_compute);

Tensor trD = make_tensor<ElementOutput>(Shape<Int<FragmentSize>>{});
auto trD_frag = recast<Array<ElementOutput, FragmentSize>>(trD);
Expand Down Expand Up @@ -423,7 +423,7 @@ class CollectiveEpilogue<
if constexpr (is_destination_supported) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(trD_compute_frag); ++i) {
trD_frag(i) = cutlass::NumericArrayConverter<ElementOutput, ElementCompute, FragmentSize>{}(trD_compute_frag(i));
trD_frag(i) = cutlass::NumericArrayConverter<ElementOutput, ElementOutput, FragmentSize>{}(trD_compute_frag(i));
}
copy(params.xe_store_d, trD, tCgD(_, epi_m, epi_n));
}
Expand Down
4 changes: 2 additions & 2 deletions test/unit/gemm/device/default_gemm_configuration.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1465,7 +1465,7 @@ struct DefaultGemmConfigurationToCutlass3Types<
cutlass::gemm::collective::KernelScheduleAuto
>::CollectiveOp;

using EpilogueOp = epilogue::fusion::LinearCombination<float, float>;
using EpilogueOp = epilogue::fusion::LinearCombination<ElementOutput, float>;

using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<
epilogue::IntelXeXMX16,
Expand Down Expand Up @@ -1567,7 +1567,7 @@ struct DefaultGemmConfigurationToCutlass3Types<
cutlass::gemm::collective::KernelScheduleAuto
>::CollectiveOp;

using EpilogueOp = epilogue::fusion::LinearCombination<float, float>;
using EpilogueOp = epilogue::fusion::LinearCombination<ElementOutput, float>;

using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<
epilogue::IntelXeXMX16,
Expand Down
2 changes: 1 addition & 1 deletion test/unit/gemm/device/default_gemm_group_configuration.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ struct DefaultGemmGroupConfiguration<

using TiledMma = typename CollectiveMainloop::TiledMma;

using EpilogueOp = epilogue::fusion::LinearCombination<float, float>;
using EpilogueOp = epilogue::fusion::LinearCombination<ElementOutput, float>;

using FusionCallBacks = epilogue::fusion::FusionCallbacks<
epilogue::IntelXeXMX16Group,
Expand Down
Loading