From ef9485020abe4060c3017e13df95a405cf53f628 Mon Sep 17 00:00:00 2001 From: "Joy, Albin" Date: Mon, 13 Oct 2025 07:55:34 +0000 Subject: [PATCH 1/5] Support different dtype for LinearCombination Fix for supporting different data type for LinearCombination Output and Compute. --- examples/02_bmg_gemm_mixed_dtype/02_bmg_gemm_f16_u4_f16.cpp | 4 ++-- .../10_bmg_grouped_gemm_f16_u4.cpp | 2 +- include/cutlass/epilogue/collective/xe_epilogue.hpp | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/02_bmg_gemm_mixed_dtype/02_bmg_gemm_f16_u4_f16.cpp b/examples/02_bmg_gemm_mixed_dtype/02_bmg_gemm_f16_u4_f16.cpp index 335976d8ac..5c8648eb54 100755 --- a/examples/02_bmg_gemm_mixed_dtype/02_bmg_gemm_f16_u4_f16.cpp +++ b/examples/02_bmg_gemm_mixed_dtype/02_bmg_gemm_f16_u4_f16.cpp @@ -263,7 +263,7 @@ struct ExampleRunner { using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16; using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; - using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; - using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16Group; - using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks>(accumulators); - auto trD_compute_frag = recast>(trD_compute); + auto trD_compute_frag = recast>(trD_compute); Tensor trD = make_tensor(Shape>{}); auto trD_frag = recast>(trD); @@ -423,7 +423,7 @@ class CollectiveEpilogue< if constexpr (is_destination_supported) { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(trD_compute_frag); ++i) { - trD_frag(i) = cutlass::NumericArrayConverter{}(trD_compute_frag(i)); + trD_frag(i) = cutlass::NumericArrayConverter{}(trD_compute_frag(i)); } copy(params.xe_store_d, trD, tCgD(_, epi_m, epi_n)); } From 64e777ae7c72cfd337f2aeba779c8d9f5af3e7e9 Mon Sep 17 00:00:00 2001 From: "Joy, Albin" Date: Fri, 17 Oct 2025 12:38:23 +0000 Subject: [PATCH 2/5] Correct usage of LinearCombination with proper inputs. --- examples/02_bmg_gemm_mixed_dtype/02_bmg_gemm_f16_u4_s8.cpp | 4 ++-- .../bmg_grouped_gemm_mixed_dtype_runner.hpp | 2 +- include/cutlass/epilogue/collective/xe_array_epilogue.hpp | 6 +++--- include/cutlass/epilogue/collective/xe_epilogue.hpp | 4 ++-- test/unit/gemm/device/default_gemm_configuration.hpp | 4 ++-- test/unit/gemm/device/default_gemm_group_configuration.hpp | 2 +- .../gemm/device/xe_gemm_fp16_fp16_fp16_tensor_op_fp32.cpp | 2 +- 7 files changed, 12 insertions(+), 12 deletions(-) diff --git a/examples/02_bmg_gemm_mixed_dtype/02_bmg_gemm_f16_u4_s8.cpp b/examples/02_bmg_gemm_mixed_dtype/02_bmg_gemm_f16_u4_s8.cpp index efd355eeff..a1975f607e 100755 --- a/examples/02_bmg_gemm_mixed_dtype/02_bmg_gemm_f16_u4_s8.cpp +++ b/examples/02_bmg_gemm_mixed_dtype/02_bmg_gemm_f16_u4_s8.cpp @@ -258,7 +258,7 @@ struct ExampleRunner { using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16; using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; - using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; - using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; - using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks>, + fusion::LinearCombination>, "Only Linear Combination Epilogue is supported for Grouped GEMM at the moment."); static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; @@ -411,7 +411,7 @@ class CollectiveEpilogue< cst_callbacks.begin(); auto acc_frag = recast>(accumulators); - auto trD_compute_frag = recast>(trD_compute); + auto trD_compute_frag = recast>(trD_compute); Tensor trD = make_tensor(Shape>{}); auto trD_frag = recast>(trD); @@ -445,7 +445,7 @@ class CollectiveEpilogue< if constexpr (is_destination_supported) { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(trD_compute_frag); ++i) { - trD_frag(i) = cutlass::NumericArrayConverter{}(trD_compute_frag(i)); + trD_frag(i) = cutlass::NumericArrayConverter{}(trD_compute_frag(i)); } copy(params.xe_store_d.with(get<1>(load_store_tensors)), trD, tCgD(_, epi_m, epi_n)); } diff --git a/include/cutlass/epilogue/collective/xe_epilogue.hpp b/include/cutlass/epilogue/collective/xe_epilogue.hpp index f55b66f7b0..463498e8ca 100644 --- a/include/cutlass/epilogue/collective/xe_epilogue.hpp +++ b/include/cutlass/epilogue/collective/xe_epilogue.hpp @@ -389,7 +389,7 @@ class CollectiveEpilogue< cst_callbacks.begin(); auto acc_frag = recast>(accumulators); - auto trD_compute_frag = recast>(trD_compute); + auto trD_compute_frag = recast>(trD_compute); Tensor trD = make_tensor(Shape>{}); auto trD_frag = recast>(trD); @@ -423,7 +423,7 @@ class CollectiveEpilogue< if constexpr (is_destination_supported) { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(trD_compute_frag); ++i) { - trD_frag(i) = cutlass::NumericArrayConverter{}(trD_compute_frag(i)); + trD_frag(i) = cutlass::NumericArrayConverter{}(trD_compute_frag(i)); } copy(params.xe_store_d, trD, tCgD(_, epi_m, epi_n)); } diff --git a/test/unit/gemm/device/default_gemm_configuration.hpp b/test/unit/gemm/device/default_gemm_configuration.hpp index a851e6110f..fa6b1bc5ac 100644 --- a/test/unit/gemm/device/default_gemm_configuration.hpp +++ b/test/unit/gemm/device/default_gemm_configuration.hpp @@ -1465,7 +1465,7 @@ struct DefaultGemmConfigurationToCutlass3Types< cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = epilogue::fusion::LinearCombination; + using EpilogueOp = epilogue::fusion::LinearCombination; using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks< epilogue::IntelXeXMX16, @@ -1567,7 +1567,7 @@ struct DefaultGemmConfigurationToCutlass3Types< cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = epilogue::fusion::LinearCombination; + using EpilogueOp = epilogue::fusion::LinearCombination; using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks< epilogue::IntelXeXMX16, diff --git a/test/unit/gemm/device/default_gemm_group_configuration.hpp b/test/unit/gemm/device/default_gemm_group_configuration.hpp index 9dd5ab06f0..cccfa3dbf6 100644 --- a/test/unit/gemm/device/default_gemm_group_configuration.hpp +++ b/test/unit/gemm/device/default_gemm_group_configuration.hpp @@ -87,7 +87,7 @@ struct DefaultGemmGroupConfiguration< using TiledMma = typename CollectiveMainloop::TiledMma; - using EpilogueOp = epilogue::fusion::LinearCombination; + using EpilogueOp = epilogue::fusion::LinearCombination; using FusionCallBacks = epilogue::fusion::FusionCallbacks< epilogue::IntelXeXMX16Group, diff --git a/test/unit/gemm/device/xe_gemm_fp16_fp16_fp16_tensor_op_fp32.cpp b/test/unit/gemm/device/xe_gemm_fp16_fp16_fp16_tensor_op_fp32.cpp index 68934691a4..d061b33343 100644 --- a/test/unit/gemm/device/xe_gemm_fp16_fp16_fp16_tensor_op_fp32.cpp +++ b/test/unit/gemm/device/xe_gemm_fp16_fp16_fp16_tensor_op_fp32.cpp @@ -51,7 +51,7 @@ struct XE_Device_Gemm_fp16_fp16_f16_tensor_op_f32 { cute::half_t, LayoutA, cute::half_t, LayoutB, float, layout::RowMajor, - cute::half_t>; + float>; using Gemm = gemm::device::GemmUniversalAdapter, From 528c1c5bbb098a5cc4800820bfb60ae2085a3ab2 Mon Sep 17 00:00:00 2001 From: "Joy, Albin" Date: Fri, 17 Oct 2025 13:08:15 +0000 Subject: [PATCH 3/5] Fix merge conflict --- include/cutlass/epilogue/collective/xe_epilogue.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/cutlass/epilogue/collective/xe_epilogue.hpp b/include/cutlass/epilogue/collective/xe_epilogue.hpp index 463498e8ca..f55b66f7b0 100644 --- a/include/cutlass/epilogue/collective/xe_epilogue.hpp +++ b/include/cutlass/epilogue/collective/xe_epilogue.hpp @@ -389,7 +389,7 @@ class CollectiveEpilogue< cst_callbacks.begin(); auto acc_frag = recast>(accumulators); - auto trD_compute_frag = recast>(trD_compute); + auto trD_compute_frag = recast>(trD_compute); Tensor trD = make_tensor(Shape>{}); auto trD_frag = recast>(trD); @@ -423,7 +423,7 @@ class CollectiveEpilogue< if constexpr (is_destination_supported) { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(trD_compute_frag); ++i) { - trD_frag(i) = cutlass::NumericArrayConverter{}(trD_compute_frag(i)); + trD_frag(i) = cutlass::NumericArrayConverter{}(trD_compute_frag(i)); } copy(params.xe_store_d, trD, tCgD(_, epi_m, epi_n)); } From cfac845ea5e3837458f932e41464e5d79da715f7 Mon Sep 17 00:00:00 2001 From: "Joy, Albin" Date: Wed, 22 Oct 2025 12:38:42 +0000 Subject: [PATCH 4/5] Fix benchmarks MixedPrecision testcases --- benchmarks/gemm/benchmarks_sycl.hpp | 25 +++++++++++++-------- benchmarks/gemm/gemm_configuration_sycl.hpp | 6 +++-- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/benchmarks/gemm/benchmarks_sycl.hpp b/benchmarks/gemm/benchmarks_sycl.hpp index ee28c1cfce..cd7c703df2 100644 --- a/benchmarks/gemm/benchmarks_sycl.hpp +++ b/benchmarks/gemm/benchmarks_sycl.hpp @@ -319,12 +319,13 @@ using PvcMixedPrecisionGemmFP16U4FP16F16FP16S4_RCR_1 = cutlass::gemm::device::Mi cutlass::half_t, cutlass::layout::RowMajor, cutlass::uint4_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor, cutlass::half_t, cute::Stride<_1, int64_t, int64_t>, cutlass::int4_t, cute::Stride<_8, cute::Stride<_1, int64_t>, int64_t>, Shape<_32, _128, _32>, Scheduler::Gemm, typename TiledMMAHelper, Layout>, Layout, Stride<_4, _1, _0>>>::TiledMMA, - XE_2D_U16x32x32_LD_N, XE_2D_U4x32x16_LD_T, XE_2D_U16x8x16_ST_N, + XE_2D_U16x32x32_LD_N, XE_2D_U4x32x16_LD_T, XE_2D_U32x8x16_ST_N, cutlass::epilogue::fusion::LinearCombination, 2 @@ -335,12 +336,13 @@ using PvcMixedPrecisionGemmBF16U4BF16BF16BF16S4_RCR_1 = cutlass::gemm::device::M cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::uint4_t, cutlass::layout::ColumnMajor, cutlass::bfloat16_t, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor, cutlass::bfloat16_t, cute::Stride<_1, int64_t, int64_t>, cutlass::int4_t, cute::Stride<_8, cute::Stride<_1, int64_t>, int64_t>, Shape<_32, _128, _32>, Scheduler::Gemm, typename TiledMMAHelper, Layout>, Layout, Stride<_4, _1, _0>>>::TiledMMA, - XE_2D_U16x32x32_LD_N, XE_2D_U4x32x16_LD_T, XE_2D_U16x8x16_ST_N, + XE_2D_U16x32x32_LD_N, XE_2D_U4x32x16_LD_T, XE_2D_U32x8x16_ST_N, cutlass::epilogue::fusion::LinearCombination, 2 @@ -351,28 +353,29 @@ using PvcMixedPrecisionGemmFP16U4FP16S8FP16S4_RCR_1 = cutlass::gemm::device::Mix cutlass::half_t, cutlass::layout::RowMajor, cutlass::uint4_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::RowMajor, + int, cutlass::layout::RowMajor, cutlass::half_t, cute::Stride<_1, int64_t, int64_t>, cutlass::int4_t, cute::Stride<_8, cute::Stride<_1, int64_t>, int64_t>, Shape<_32, _128, _32>, Scheduler::Gemm, typename TiledMMAHelper, Layout>, Layout, Stride<_4, _1, _0>>>::TiledMMA, - XE_2D_Packed_U16x32x32_LD_N, XE_2D_U4x32x16_LD_T, XE_2D_U16x8x16_ST_N, + XE_2D_Packed_U16x32x32_LD_N, XE_2D_U4x32x16_LD_T, XE_2D_U32x8x16_ST_N, cutlass::epilogue::fusion::LinearCombination, 2 >; - using PvcMixedPrecisionGemmFP16U4S8S8FP16S4_RCR_1 = cutlass::gemm::device::MixedPrecisionGemmConfiguration< cutlass::arch::IntelXe, cutlass::half_t, cutlass::layout::RowMajor, cutlass::uint4_t, cutlass::layout::ColumnMajor, cutlass::int8_t, cutlass::layout::RowMajor, + int, cutlass::layout::RowMajor, cutlass::half_t, cute::Stride<_1, int64_t, int64_t>, cutlass::int4_t, cute::Stride<_8, cute::Stride<_1, int64_t>, int64_t>, Shape<_32, _128, _32>, Scheduler::Gemm, typename TiledMMAHelper, Layout>, Layout, Stride<_4, _1, _0>>>::TiledMMA, - XE_2D_Packed_U16x32x32_LD_N, XE_2D_U4x32x16_LD_T, XE_2D_U8x8x16_ST_N, + XE_2D_Packed_U16x32x32_LD_N, XE_2D_U4x32x16_LD_T, XE_2D_U32x8x16_ST_N, cutlass::epilogue::fusion::LinearCombination, 2 @@ -383,12 +386,13 @@ using PvcMixedPrecisionGemmBF16U4BF16S8BF16S4_RCR_1 = cutlass::gemm::device::Mix cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::uint4_t, cutlass::layout::ColumnMajor, cutlass::bfloat16_t, cutlass::layout::RowMajor, + int, cutlass::layout::RowMajor, cutlass::bfloat16_t, cute::Stride<_1, int64_t, int64_t>, cutlass::int4_t, cute::Stride<_8, cute::Stride<_1, int64_t>, int64_t>, Shape<_32, _128, _32>, Scheduler::Gemm, typename TiledMMAHelper, Layout>, Layout, Stride<_4, _1, _0>>>::TiledMMA, - XE_2D_Packed_U16x32x32_LD_N, XE_2D_U4x32x16_LD_T, XE_2D_U16x8x16_ST_N, + XE_2D_Packed_U16x32x32_LD_N, XE_2D_U4x32x16_LD_T, XE_2D_U32x8x16_ST_N, cutlass::epilogue::fusion::LinearCombination, 2 @@ -399,12 +403,13 @@ using PvcMixedPrecisionGemmBF16U4S8S8BF16S4_RCR_1 = cutlass::gemm::device::Mixed cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::uint4_t, cutlass::layout::ColumnMajor, cutlass::int8_t, cutlass::layout::RowMajor, + int, cutlass::layout::RowMajor, cutlass::bfloat16_t, cute::Stride<_1, int64_t, int64_t>, cutlass::int4_t, cute::Stride<_8, cute::Stride<_1, int64_t>, int64_t>, Shape<_32, _128, _32>, Scheduler::Gemm, typename TiledMMAHelper, Layout>, Layout, Stride<_4, _1, _0>>>::TiledMMA, - XE_2D_Packed_U16x32x32_LD_N, XE_2D_U4x32x16_LD_T, XE_2D_U8x8x16_ST_N, + XE_2D_Packed_U16x32x32_LD_N, XE_2D_U4x32x16_LD_T, XE_2D_U32x8x16_ST_N, cutlass::epilogue::fusion::LinearCombination, 2 @@ -415,12 +420,13 @@ using PvcMixedPrecisionGemmBF16S8BF16S8BF16S8_RCR_1 = cutlass::gemm::device::Mix cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::int8_t, cutlass::layout::ColumnMajor, cutlass::bfloat16_t, cutlass::layout::RowMajor, + int, cutlass::layout::RowMajor, cutlass::bfloat16_t, cute::Stride<_1, int64_t, int64_t>, cutlass::int8_t, cute::Stride<_1, int64_t, int64_t>, Shape<_32, _128, _32>, Scheduler::Gemm, typename TiledMMAHelper, Layout>, Layout, Stride<_4, _1, _0>>>::TiledMMA, - XE_2D_Packed_U16x32x32_LD_N, XE_2D_U8x16x32_LD_T, XE_2D_U16x8x16_ST_N, + XE_2D_Packed_U16x32x32_LD_N, XE_2D_U8x16x32_LD_T, XE_2D_U32x8x16_ST_N, cutlass::epilogue::fusion::LinearCombination, 2 @@ -431,12 +437,13 @@ using PvcMixedPrecisionGemmFP16S8FP16S8FP16S8_RCR_1 = cutlass::gemm::device::Mix cutlass::half_t, cutlass::layout::RowMajor, cutlass::int8_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::RowMajor, + int, cutlass::layout::RowMajor, cutlass::half_t, cute::Stride<_1, int64_t, int64_t>, cutlass::int8_t, cute::Stride<_1, int64_t, int64_t>, Shape<_32, _128, _32>, Scheduler::Gemm, typename TiledMMAHelper, Layout>, Layout, Stride<_4, _1, _0>>>::TiledMMA, - XE_2D_Packed_U16x32x32_LD_N, XE_2D_U8x16x32_LD_T, XE_2D_U16x8x16_ST_N, + XE_2D_Packed_U16x32x32_LD_N, XE_2D_U8x16x32_LD_T, XE_2D_U32x8x16_ST_N, cutlass::epilogue::fusion::LinearCombination, 2 diff --git a/benchmarks/gemm/gemm_configuration_sycl.hpp b/benchmarks/gemm/gemm_configuration_sycl.hpp index 0576e33382..08cdb32875 100644 --- a/benchmarks/gemm/gemm_configuration_sycl.hpp +++ b/benchmarks/gemm/gemm_configuration_sycl.hpp @@ -70,6 +70,7 @@ template< class ArchTag, class ElementA, class LayoutA, class ElementB, class LayoutB, class ElementC, typename LayoutC, + class ElementD, typename LayoutD, class ElementScale, typename StrideS, class ElementZero, typename StrideZ, class TileShape, Scheduler TileScheduler, @@ -175,6 +176,7 @@ struct GemmConfiguration< template { - using LayoutD = LayoutC; using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16MixedPrecision; using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; @@ -205,7 +207,7 @@ struct MixedPrecisionGemmConfiguration< TileShape, ElementAccumulator, cutlass::gemm::TagToStrideC_t, - ElementC, + ElementD, cutlass::gemm::TagToStrideC_t, FusionCallBacks, XE_2D_U32x8x16_LD_N, From 66322ccdb418c483a1431dfe70bce5a3aac6f9ea Mon Sep 17 00:00:00 2001 From: "Joy, Albin" Date: Wed, 22 Oct 2025 14:35:28 +0000 Subject: [PATCH 5/5] Fix merge conflicts --- test/unit/gemm/device/xe_gemm_fp16_fp16_fp16_tensor_op_fp32.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/unit/gemm/device/xe_gemm_fp16_fp16_fp16_tensor_op_fp32.cpp b/test/unit/gemm/device/xe_gemm_fp16_fp16_fp16_tensor_op_fp32.cpp index d061b33343..68934691a4 100644 --- a/test/unit/gemm/device/xe_gemm_fp16_fp16_fp16_tensor_op_fp32.cpp +++ b/test/unit/gemm/device/xe_gemm_fp16_fp16_fp16_tensor_op_fp32.cpp @@ -51,7 +51,7 @@ struct XE_Device_Gemm_fp16_fp16_f16_tensor_op_f32 { cute::half_t, LayoutA, cute::half_t, LayoutB, float, layout::RowMajor, - float>; + cute::half_t>; using Gemm = gemm::device::GemmUniversalAdapter,