diff --git a/benchmarks/gemm/benchmarks_sycl.hpp b/benchmarks/gemm/benchmarks_sycl.hpp index ee28c1cfce..cd7c703df2 100644 --- a/benchmarks/gemm/benchmarks_sycl.hpp +++ b/benchmarks/gemm/benchmarks_sycl.hpp @@ -319,12 +319,13 @@ using PvcMixedPrecisionGemmFP16U4FP16F16FP16S4_RCR_1 = cutlass::gemm::device::Mi cutlass::half_t, cutlass::layout::RowMajor, cutlass::uint4_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor, cutlass::half_t, cute::Stride<_1, int64_t, int64_t>, cutlass::int4_t, cute::Stride<_8, cute::Stride<_1, int64_t>, int64_t>, Shape<_32, _128, _32>, Scheduler::Gemm, typename TiledMMAHelper, Layout>, Layout, Stride<_4, _1, _0>>>::TiledMMA, - XE_2D_U16x32x32_LD_N, XE_2D_U4x32x16_LD_T, XE_2D_U16x8x16_ST_N, + XE_2D_U16x32x32_LD_N, XE_2D_U4x32x16_LD_T, XE_2D_U32x8x16_ST_N, cutlass::epilogue::fusion::LinearCombination, 2 @@ -335,12 +336,13 @@ using PvcMixedPrecisionGemmBF16U4BF16BF16BF16S4_RCR_1 = cutlass::gemm::device::M cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::uint4_t, cutlass::layout::ColumnMajor, cutlass::bfloat16_t, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor, cutlass::bfloat16_t, cute::Stride<_1, int64_t, int64_t>, cutlass::int4_t, cute::Stride<_8, cute::Stride<_1, int64_t>, int64_t>, Shape<_32, _128, _32>, Scheduler::Gemm, typename TiledMMAHelper, Layout>, Layout, Stride<_4, _1, _0>>>::TiledMMA, - XE_2D_U16x32x32_LD_N, XE_2D_U4x32x16_LD_T, XE_2D_U16x8x16_ST_N, + XE_2D_U16x32x32_LD_N, XE_2D_U4x32x16_LD_T, XE_2D_U32x8x16_ST_N, cutlass::epilogue::fusion::LinearCombination, 2 @@ -351,28 +353,29 @@ using PvcMixedPrecisionGemmFP16U4FP16S8FP16S4_RCR_1 = cutlass::gemm::device::Mix cutlass::half_t, cutlass::layout::RowMajor, cutlass::uint4_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::RowMajor, + int, cutlass::layout::RowMajor, cutlass::half_t, cute::Stride<_1, int64_t, int64_t>, cutlass::int4_t, cute::Stride<_8, cute::Stride<_1, int64_t>, int64_t>, Shape<_32, _128, _32>, Scheduler::Gemm, typename TiledMMAHelper, Layout>, Layout, Stride<_4, _1, _0>>>::TiledMMA, - XE_2D_Packed_U16x32x32_LD_N, XE_2D_U4x32x16_LD_T, XE_2D_U16x8x16_ST_N, + XE_2D_Packed_U16x32x32_LD_N, XE_2D_U4x32x16_LD_T, XE_2D_U32x8x16_ST_N, cutlass::epilogue::fusion::LinearCombination, 2 >; - using PvcMixedPrecisionGemmFP16U4S8S8FP16S4_RCR_1 = cutlass::gemm::device::MixedPrecisionGemmConfiguration< cutlass::arch::IntelXe, cutlass::half_t, cutlass::layout::RowMajor, cutlass::uint4_t, cutlass::layout::ColumnMajor, cutlass::int8_t, cutlass::layout::RowMajor, + int, cutlass::layout::RowMajor, cutlass::half_t, cute::Stride<_1, int64_t, int64_t>, cutlass::int4_t, cute::Stride<_8, cute::Stride<_1, int64_t>, int64_t>, Shape<_32, _128, _32>, Scheduler::Gemm, typename TiledMMAHelper, Layout>, Layout, Stride<_4, _1, _0>>>::TiledMMA, - XE_2D_Packed_U16x32x32_LD_N, XE_2D_U4x32x16_LD_T, XE_2D_U8x8x16_ST_N, + XE_2D_Packed_U16x32x32_LD_N, XE_2D_U4x32x16_LD_T, XE_2D_U32x8x16_ST_N, cutlass::epilogue::fusion::LinearCombination, 2 @@ -383,12 +386,13 @@ using PvcMixedPrecisionGemmBF16U4BF16S8BF16S4_RCR_1 = cutlass::gemm::device::Mix cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::uint4_t, cutlass::layout::ColumnMajor, cutlass::bfloat16_t, cutlass::layout::RowMajor, + int, cutlass::layout::RowMajor, cutlass::bfloat16_t, cute::Stride<_1, int64_t, int64_t>, cutlass::int4_t, cute::Stride<_8, cute::Stride<_1, int64_t>, int64_t>, Shape<_32, _128, _32>, Scheduler::Gemm, typename TiledMMAHelper, Layout>, Layout, Stride<_4, _1, _0>>>::TiledMMA, - XE_2D_Packed_U16x32x32_LD_N, XE_2D_U4x32x16_LD_T, XE_2D_U16x8x16_ST_N, + XE_2D_Packed_U16x32x32_LD_N, XE_2D_U4x32x16_LD_T, XE_2D_U32x8x16_ST_N, cutlass::epilogue::fusion::LinearCombination, 2 @@ -399,12 +403,13 @@ using PvcMixedPrecisionGemmBF16U4S8S8BF16S4_RCR_1 = cutlass::gemm::device::Mixed cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::uint4_t, cutlass::layout::ColumnMajor, cutlass::int8_t, cutlass::layout::RowMajor, + int, cutlass::layout::RowMajor, cutlass::bfloat16_t, cute::Stride<_1, int64_t, int64_t>, cutlass::int4_t, cute::Stride<_8, cute::Stride<_1, int64_t>, int64_t>, Shape<_32, _128, _32>, Scheduler::Gemm, typename TiledMMAHelper, Layout>, Layout, Stride<_4, _1, _0>>>::TiledMMA, - XE_2D_Packed_U16x32x32_LD_N, XE_2D_U4x32x16_LD_T, XE_2D_U8x8x16_ST_N, + XE_2D_Packed_U16x32x32_LD_N, XE_2D_U4x32x16_LD_T, XE_2D_U32x8x16_ST_N, cutlass::epilogue::fusion::LinearCombination, 2 @@ -415,12 +420,13 @@ using PvcMixedPrecisionGemmBF16S8BF16S8BF16S8_RCR_1 = cutlass::gemm::device::Mix cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::int8_t, cutlass::layout::ColumnMajor, cutlass::bfloat16_t, cutlass::layout::RowMajor, + int, cutlass::layout::RowMajor, cutlass::bfloat16_t, cute::Stride<_1, int64_t, int64_t>, cutlass::int8_t, cute::Stride<_1, int64_t, int64_t>, Shape<_32, _128, _32>, Scheduler::Gemm, typename TiledMMAHelper, Layout>, Layout, Stride<_4, _1, _0>>>::TiledMMA, - XE_2D_Packed_U16x32x32_LD_N, XE_2D_U8x16x32_LD_T, XE_2D_U16x8x16_ST_N, + XE_2D_Packed_U16x32x32_LD_N, XE_2D_U8x16x32_LD_T, XE_2D_U32x8x16_ST_N, cutlass::epilogue::fusion::LinearCombination, 2 @@ -431,12 +437,13 @@ using PvcMixedPrecisionGemmFP16S8FP16S8FP16S8_RCR_1 = cutlass::gemm::device::Mix cutlass::half_t, cutlass::layout::RowMajor, cutlass::int8_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::RowMajor, + int, cutlass::layout::RowMajor, cutlass::half_t, cute::Stride<_1, int64_t, int64_t>, cutlass::int8_t, cute::Stride<_1, int64_t, int64_t>, Shape<_32, _128, _32>, Scheduler::Gemm, typename TiledMMAHelper, Layout>, Layout, Stride<_4, _1, _0>>>::TiledMMA, - XE_2D_Packed_U16x32x32_LD_N, XE_2D_U8x16x32_LD_T, XE_2D_U16x8x16_ST_N, + XE_2D_Packed_U16x32x32_LD_N, XE_2D_U8x16x32_LD_T, XE_2D_U32x8x16_ST_N, cutlass::epilogue::fusion::LinearCombination, 2 diff --git a/benchmarks/gemm/gemm_configuration_sycl.hpp b/benchmarks/gemm/gemm_configuration_sycl.hpp index 0576e33382..08cdb32875 100644 --- a/benchmarks/gemm/gemm_configuration_sycl.hpp +++ b/benchmarks/gemm/gemm_configuration_sycl.hpp @@ -70,6 +70,7 @@ template< class ArchTag, class ElementA, class LayoutA, class ElementB, class LayoutB, class ElementC, typename LayoutC, + class ElementD, typename LayoutD, class ElementScale, typename StrideS, class ElementZero, typename StrideZ, class TileShape, Scheduler TileScheduler, @@ -175,6 +176,7 @@ struct GemmConfiguration< template { - using LayoutD = LayoutC; using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16MixedPrecision; using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; @@ -205,7 +207,7 @@ struct MixedPrecisionGemmConfiguration< TileShape, ElementAccumulator, cutlass::gemm::TagToStrideC_t, - ElementC, + ElementD, cutlass::gemm::TagToStrideC_t, FusionCallBacks, XE_2D_U32x8x16_LD_N, diff --git a/examples/02_bmg_gemm_mixed_dtype/02_bmg_gemm_f16_u4_f16.cpp b/examples/02_bmg_gemm_mixed_dtype/02_bmg_gemm_f16_u4_f16.cpp index 335976d8ac..5c8648eb54 100755 --- a/examples/02_bmg_gemm_mixed_dtype/02_bmg_gemm_f16_u4_f16.cpp +++ b/examples/02_bmg_gemm_mixed_dtype/02_bmg_gemm_f16_u4_f16.cpp @@ -263,7 +263,7 @@ struct ExampleRunner { using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16; using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; - using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; - using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; - using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; - using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16Group; - using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; - using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks>, + fusion::LinearCombination>, "Only Linear Combination Epilogue is supported for Grouped GEMM at the moment."); static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; @@ -411,7 +411,7 @@ class CollectiveEpilogue< cst_callbacks.begin(); auto acc_frag = recast>(accumulators); - auto trD_compute_frag = recast>(trD_compute); + auto trD_compute_frag = recast>(trD_compute); Tensor trD = make_tensor(Shape>{}); auto trD_frag = recast>(trD); @@ -445,7 +445,7 @@ class CollectiveEpilogue< if constexpr (is_destination_supported) { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(trD_compute_frag); ++i) { - trD_frag(i) = cutlass::NumericArrayConverter{}(trD_compute_frag(i)); + trD_frag(i) = cutlass::NumericArrayConverter{}(trD_compute_frag(i)); } copy(params.xe_store_d.with(get<1>(load_store_tensors)), trD, tCgD(_, epi_m, epi_n)); } diff --git a/include/cutlass/epilogue/collective/xe_epilogue.hpp b/include/cutlass/epilogue/collective/xe_epilogue.hpp index 463498e8ca..f55b66f7b0 100644 --- a/include/cutlass/epilogue/collective/xe_epilogue.hpp +++ b/include/cutlass/epilogue/collective/xe_epilogue.hpp @@ -389,7 +389,7 @@ class CollectiveEpilogue< cst_callbacks.begin(); auto acc_frag = recast>(accumulators); - auto trD_compute_frag = recast>(trD_compute); + auto trD_compute_frag = recast>(trD_compute); Tensor trD = make_tensor(Shape>{}); auto trD_frag = recast>(trD); @@ -423,7 +423,7 @@ class CollectiveEpilogue< if constexpr (is_destination_supported) { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(trD_compute_frag); ++i) { - trD_frag(i) = cutlass::NumericArrayConverter{}(trD_compute_frag(i)); + trD_frag(i) = cutlass::NumericArrayConverter{}(trD_compute_frag(i)); } copy(params.xe_store_d, trD, tCgD(_, epi_m, epi_n)); } diff --git a/test/unit/gemm/device/default_gemm_configuration.hpp b/test/unit/gemm/device/default_gemm_configuration.hpp index a851e6110f..fa6b1bc5ac 100644 --- a/test/unit/gemm/device/default_gemm_configuration.hpp +++ b/test/unit/gemm/device/default_gemm_configuration.hpp @@ -1465,7 +1465,7 @@ struct DefaultGemmConfigurationToCutlass3Types< cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = epilogue::fusion::LinearCombination; + using EpilogueOp = epilogue::fusion::LinearCombination; using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks< epilogue::IntelXeXMX16, @@ -1567,7 +1567,7 @@ struct DefaultGemmConfigurationToCutlass3Types< cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = epilogue::fusion::LinearCombination; + using EpilogueOp = epilogue::fusion::LinearCombination; using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks< epilogue::IntelXeXMX16, diff --git a/test/unit/gemm/device/default_gemm_group_configuration.hpp b/test/unit/gemm/device/default_gemm_group_configuration.hpp index 9dd5ab06f0..cccfa3dbf6 100644 --- a/test/unit/gemm/device/default_gemm_group_configuration.hpp +++ b/test/unit/gemm/device/default_gemm_group_configuration.hpp @@ -87,7 +87,7 @@ struct DefaultGemmGroupConfiguration< using TiledMma = typename CollectiveMainloop::TiledMma; - using EpilogueOp = epilogue::fusion::LinearCombination; + using EpilogueOp = epilogue::fusion::LinearCombination; using FusionCallBacks = epilogue::fusion::FusionCallbacks< epilogue::IntelXeXMX16Group,