From 6e2a6ac2b927ee4eddf0dc4a6a39eaa0b6561a90 Mon Sep 17 00:00:00 2001 From: yhma Date: Mon, 13 Oct 2025 16:38:46 +0800 Subject: [PATCH 1/2] Rewrite mma unit tests --- test/unit/cute/intel_xe/mma.cpp | 303 ++++++++++---------------------- 1 file changed, 90 insertions(+), 213 deletions(-) diff --git a/test/unit/cute/intel_xe/mma.cpp b/test/unit/cute/intel_xe/mma.cpp index d30c8ae8d7..9c67a1bc68 100755 --- a/test/unit/cute/intel_xe/mma.cpp +++ b/test/unit/cute/intel_xe/mma.cpp @@ -30,284 +30,161 @@ * **************************************************************************************************/ -#include "cutlass/detail/layout.hpp" +#include "cutlass_unit_test.h" #include -#include -#include -#include "cutlass_unit_test.h" -#include "utils.hpp" +#include "../cooperative_gemm_common.hpp" using namespace cute; -using namespace cutlass; -using namespace compat::experimental; - -#define SUBGROUP_SIZE (16) - -template class GemmDeviceName; - -template -void gemm_device(TA const *A, TB const *B, TC *C, uint32_t m, uint32_t n, - uint32_t k) { - using namespace cute; - - // Represent the full tensors - Tensor mA = make_tensor(make_gmem_ptr(A), - make_layout(make_shape(m, k), make_stride(k, 1))); - Tensor mB = make_tensor(make_gmem_ptr(B), - make_layout(make_shape(n, k), make_stride(1, n))); - Tensor mC = make_tensor(make_gmem_ptr(C), - make_layout(make_shape(m, n), make_stride(n, 1))); - - // Get the appropriate blocks for this thread block - auto cta_coord = make_coord(BlockIdxX(), BlockIdxY(), _); // (m,n,k) - - auto cta_tiler = - make_shape(Int{}, Int{}, Int{}); - Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X, _1>{}); - Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step{}); - Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1, _1, X>{}); - - TiledMMA mma = make_tiled_mma( - MMA_Atom{}, - Layout< // Require: subgroup_layout - Shape, - Int, _1>>{}); - - ThrMMA thrd_mma = mma.get_slice(ThreadIdxX()); - - Tensor tgA = thrd_mma.partition_A(gA); - Tensor fragment_A = - thrd_mma.make_fragment_A(tgA(_, _, _, 0)); // (MMA, MMA_M, MMA_K) - - Tensor tgB = thrd_mma.partition_B(gB); - Tensor fragment_B = - thrd_mma.make_fragment_B(tgB(_, _, _, 0)); // (MMA, MMA_N, MMA_K) - - Tensor tgC = thrd_mma.partition_C(gC); - Tensor fragment_C = thrd_mma.make_fragment_C(tgC); // (MMA, MMA_M, MMA_N) - clear(fragment_C); - -#define CUTLASS_ENABLE_DEBUG_PRINTS (0) - -#undef LOG_THREAD -#define LOG_THREAD (16) - -#if CUTLASS_ENABLE_DEBUG_PRINTS - if (thread(LOG_THREAD)) { - print("===================== A :\n"); - - print(" mA : "); - print(mA); - print("\n"); - print(" gA : "); - print(gA); - print("\n"); - print("tgA : "); - print(tgA); - print("\n"); - print("fragment_A : "); - print(fragment_A); - print("\n\n"); - } -#endif - -#if CUTLASS_ENABLE_DEBUG_PRINTS - if (thread(LOG_THREAD)) { - print("===================== B :\n"); - - print(" mB : "); - print(mB); - print("\n"); - print(" gB : "); - print(gB); - print("\n"); - print("tgB : "); - print(tgB); - print("\n"); - print("fragment_B : "); - print(fragment_B); - print("\n\n"); - } -#endif - -#if CUTLASS_ENABLE_DEBUG_PRINTS - if (thread(LOG_THREAD)) { - print("===================== C :\n"); - print(" mC : "); - print(mC); - print("\n"); - print(" gC : "); - print(gC); - print("\n"); - print("tgC : "); - print(tgC); - print("\n"); - print("fragment_C : "); - print(fragment_C); - print("\n\n"); - } -#endif - - auto k_tile_max = size<3>(tgA); - for (int k_tile = 0; k_tile < k_tile_max; ++k_tile) { - auto kA = tgA(_, _, _, k_tile); - auto kB = tgB(_, _, _, k_tile); - // Copy gmem to rmem for k_tile+1 with tA|tB thread-partitioned tensors - copy(kA, fragment_A); - copy(kB, fragment_B); - - // Compute gemm on mma-partitioned smem - cute::gemm(mma, fragment_A, fragment_B, fragment_C); - } - - copy(fragment_C, tgC); +namespace { + constexpr uint32_t thread_block_size = 128; + constexpr uint32_t max_vec_bits = 128; } -// Setup params for a NT GEMM -template -void gemm(int m, int n, int k, TA *A, TB *B, TC *C) { - using namespace cute; - - auto dimBlock = compat::dim3(SUBGROUP_SIZE * (wg_tile_m * wg_tile_n) / - (sg_tile_m * sg_tile_n)); - auto dimGrid = compat::dim3(size(ceil_div(m, wg_tile_m)), - size(ceil_div(n, wg_tile_n))); - - launch, GemmDeviceName>( - launch_policy{dimGrid, dimBlock, - kernel_properties{sycl_exp::sub_group_size}}, - A, B, C, m, n, k); -} - -template -void MMA_Test(int m, int n, int k) { - cutlass::host_vector h_A(m * k); - cutlass::host_vector h_B(n * k); - cutlass::host_vector h_C(m * n); - - fill_matrix(h_A); - fill_matrix(h_B); - - cutlass::device_vector d_A = h_A; - cutlass::device_vector d_B = h_B; - cutlass::device_vector d_C = h_C; - - ::gemm( - m, n, k, d_A.data(), d_B.data(), d_C.data()); - compat::wait(); - - h_C = d_C; - verify(m, n, k, h_A.data(), h_B.data(), h_C.data()); +template +void run_mma_test(ShapeMNK shape_mnk, LayoutShape layout_shape) { + auto tiled_mma = TiledMMA, Layout>{}; + test_cooperative_gemm_col_major_layout( + shape_mnk, tiled_mma); } TEST(PVC_CuTe_Xe, MMA_XE_8x16x32_S32S8S8S32_TT) { - MMA_Test(512, 512, 256); + run_mma_test, + decltype(Shape<_64, _64, _32>{}), int8_t, int8_t, int32_t>( + Shape<_64, _64, _32>{}, Shape<_2, _2, _1>{}); } TEST(PVC_CuTe_Xe, MMA_XE_4x16x32_S32S8S8S32_TT) { - MMA_Test(512, 512, 256); + run_mma_test, + decltype(Shape<_32, _64, _32>{}), int8_t, int8_t, int32_t>( + Shape<_32, _64, _32>{}, Shape<_2, _2, _1>{}); } TEST(PVC_CuTe_Xe, MMA_XE_2x16x32_S32S8S8S32_TT) { - MMA_Test(512, 512, 256); + run_mma_test, + decltype(Shape<_16, _32, _32>{}), int8_t, int8_t, int32_t>( + Shape<_16, _32, _32>{}, Shape<_4, _2, _1>{}); } TEST(PVC_CuTe_Xe, MMA_XE_1x16x32_S32S8S8S32_TT) { - MMA_Test( - 512, 512, 256); + run_mma_test, + decltype(Shape<_8, _64, _32>{}), int8_t, int8_t, int32_t>( + Shape<_8, _64, _32>{}, Shape<_1, _1, _1>{}); } TEST(PVC_CuTe_Xe, MMA_XE_8x16x32_S32U8U8S32_TT) { - MMA_Test(512, 512, 256); + run_mma_test, + decltype(Shape<_64, _64, _32>{}), uint8_t, uint8_t, int32_t>( + Shape<_64, _64, _32>{}, Shape<_2, _2, _1>{}); } TEST(PVC_CuTe_Xe, MMA_XE_4x16x32_S32U8U8S32_TT) { - MMA_Test(512, 512, 256); + run_mma_test, + decltype(Shape<_32, _64, _32>{}), uint8_t, uint8_t, int32_t>( + Shape<_32, _64, _32>{}, Shape<_2, _2, _1>{}); } TEST(PVC_CuTe_Xe, MMA_XE_2x16x32_S32U8U8S32_TT) { - MMA_Test(512, 512, 256); + run_mma_test, + decltype(Shape<_16, _32, _32>{}), uint8_t, uint8_t, int32_t>( + Shape<_16, _32, _32>{}, Shape<_4, _2, _1>{}); } TEST(PVC_CuTe_Xe, MMA_XE_1x16x32_S32U8U8S32_TT) { - MMA_Test(512, 512, 256); + run_mma_test, + decltype(Shape<_8, _64, _32>{}), uint8_t, uint8_t, int32_t>( + Shape<_8, _64, _32>{}, Shape<_1, _1, _1>{}); } TEST(PVC_CuTe_Xe, MMA_XE_8x16x16_F32BF16BF16F32_TT) { - MMA_Test(512, 512, 256); + run_mma_test, + decltype(Shape<_64, _64, _16>{}), + cutlass::bfloat16_t, cutlass::bfloat16_t, float>( + Shape<_64, _64, _16>{}, Shape<_2, _2, _1>{}); } TEST(PVC_CuTe_Xe, MMA_XE_4x16x16_F32BF16BF16F32_TT) { - MMA_Test(512, 512, 256); + run_mma_test, + decltype(Shape<_32, _64, _16>{}), + cutlass::bfloat16_t, cutlass::bfloat16_t, float>( + Shape<_32, _64, _16>{}, Shape<_2, _2, _1>{}); } TEST(PVC_CuTe_Xe, MMA_XE_2x16x16_F32BF16BF16F32_TT) { - MMA_Test(512, 512, 256); + run_mma_test, + decltype(Shape<_128, _128, _16>{}), + cutlass::bfloat16_t, cutlass::bfloat16_t, float>( + Shape<_128, _128, _16>{}, Shape<_2, _4, _1>{}); } TEST(PVC_CuTe_Xe, MMA_XE_1x16x16_F32BF16BF16F32_TT) { - MMA_Test(512, 512, 256); + run_mma_test, + decltype(Shape<_8, _64, _16>{}), + cutlass::bfloat16_t, cutlass::bfloat16_t, float>( + Shape<_8, _64, _16>{}, Shape<_1, _1, _1>{}); } TEST(PVC_CuTe_Xe, MMA_XE_8x16x16_F32F16F16F32_TT) { - MMA_Test(512, 512, 256); + run_mma_test, + decltype(Shape<_64, _64, _16>{}), + cutlass::half_t, cutlass::half_t, float>( + Shape<_64, _64, _16>{}, Shape<_2, _2, _1>{}); } TEST(PVC_CuTe_Xe, MMA_XE_4x16x16_F32F16F16F32_TT) { - MMA_Test(512, 512, 256); + run_mma_test, + decltype(Shape<_32, _64, _16>{}), + cutlass::half_t, cutlass::half_t, float>( + Shape<_32, _64, _16>{}, Shape<_2, _2, _1>{}); } TEST(PVC_CuTe_Xe, MMA_XE_2x16x16_F32F16F16F32_TT) { - MMA_Test(512, 512, 256); + run_mma_test, + decltype(Shape<_128, _128, _16>{}), + cutlass::half_t, cutlass::half_t, float>( + Shape<_128, _128, _16>{}, Shape<_4, _2, _1>{}); } TEST(PVC_CuTe_Xe, MMA_XE_1x16x16_F32F16F16F32_TT) { - MMA_Test( - 512, 512, 256); + run_mma_test, + decltype(Shape<_128, _128, _16>{}), + cutlass::half_t, cutlass::half_t, float>( + Shape<_128, _128, _16>{}, Shape<_1, _1, _1>{}); } -TEST(PVC_CuTe_Xe, FMA_XE_UniversalFMA_F32F32F32F32) { - MMA_Test, 64, 64, 8, 16, 16, float, - float, float>(512, 512, 256); +TEST(PVC_CuTe_Xe, MMA_XE_8x16x8_F32TF32TF32F32_TT) { + run_mma_test, + decltype(Shape<_64, _64, _8>{}), + cutlass::tfloat32_t, cutlass::tfloat32_t, float>( + Shape<_64, _64, _8>{}, Shape<_2, _2, _1>{}); } -TEST(PVC_CuTe_Xe, MMA_XE_1x16x8_F32TF32TF32F32_TT) { - MMA_Test(512, 512, 256); +TEST(PVC_CuTe_Xe, MMA_XE_4x16x8_F32TF32TF32F32_TT) { + run_mma_test, + decltype(Shape<_32, _64, _8>{}), + cutlass::tfloat32_t, cutlass::tfloat32_t, float>( + Shape<_32, _64, _8>{}, Shape<_2, _2, _1>{}); } TEST(PVC_CuTe_Xe, MMA_XE_2x16x8_F32TF32TF32F32_TT) { - MMA_Test(512, 512, 256); + run_mma_test, + decltype(Shape<_128, _128, _8>{}), + cutlass::tfloat32_t, cutlass::tfloat32_t, float>( + Shape<_128, _128, _8>{}, Shape<_4, _2, _1>{}); } -TEST(PVC_CuTe_Xe, MMA_XE_4x16x8_F32TF32TF32F32_TT) { - MMA_Test(512, 512, 256); +TEST(PVC_CuTe_Xe, MMA_XE_1x16x8_F32TF32TF32F32_TT) { + run_mma_test, + decltype(Shape<_8, _64, _8>{}), + cutlass::tfloat32_t, cutlass::tfloat32_t, float>( + Shape<_8, _64, _8>{}, Shape<_1, _1, _1>{}); } -TEST(PVC_CuTe_Xe, MMA_XE_8x16x8_F32TF32TF32F32_TT) { - MMA_Test(512, 512, 256); +TEST(PVC_CuTe_Xe, FMA_XE_UniversalFMA_F32F32F32F32) { + run_mma_test, Shape<_1, _1, _1>, + decltype(Shape<_128, _128, _8>{}), float, float, float>( + Shape<_128, _128, _8>{}, Shape<_1, _1, _1>{}); } From 607254f755dcb74931d94bfe1c1f1949640a57d2 Mon Sep 17 00:00:00 2001 From: yhma Date: Tue, 14 Oct 2025 09:23:12 +0800 Subject: [PATCH 2/2] Fix comments & failed CI --- test/unit/cute/intel_xe/mma.cpp | 103 +++++++++++++------------------- 1 file changed, 41 insertions(+), 62 deletions(-) diff --git a/test/unit/cute/intel_xe/mma.cpp b/test/unit/cute/intel_xe/mma.cpp index 9c67a1bc68..b4ff2634ad 100755 --- a/test/unit/cute/intel_xe/mma.cpp +++ b/test/unit/cute/intel_xe/mma.cpp @@ -43,8 +43,8 @@ namespace { constexpr uint32_t max_vec_bits = 128; } -template +template void run_mma_test(ShapeMNK shape_mnk, LayoutShape layout_shape) { auto tiled_mma = TiledMMA, Layout>{}; test_cooperative_gemm_col_major_layout( @@ -52,139 +52,118 @@ void run_mma_test(ShapeMNK shape_mnk, LayoutShape layout_shape) { } TEST(PVC_CuTe_Xe, MMA_XE_8x16x32_S32S8S8S32_TT) { - run_mma_test, - decltype(Shape<_64, _64, _32>{}), int8_t, int8_t, int32_t>( - Shape<_64, _64, _32>{}, Shape<_2, _2, _1>{}); + run_mma_test( + Shape<_128, _128, _16>{}, Shape<_2, _2, _1>{}); } TEST(PVC_CuTe_Xe, MMA_XE_4x16x32_S32S8S8S32_TT) { - run_mma_test, - decltype(Shape<_32, _64, _32>{}), int8_t, int8_t, int32_t>( - Shape<_32, _64, _32>{}, Shape<_2, _2, _1>{}); + run_mma_test( + Shape<_128, _128, _16>{}, Shape<_2, _2, _1>{}); } TEST(PVC_CuTe_Xe, MMA_XE_2x16x32_S32S8S8S32_TT) { - run_mma_test, - decltype(Shape<_16, _32, _32>{}), int8_t, int8_t, int32_t>( - Shape<_16, _32, _32>{}, Shape<_4, _2, _1>{}); + run_mma_test( + Shape<_128, _128, _16>{}, Shape<_4, _2, _1>{}); } TEST(PVC_CuTe_Xe, MMA_XE_1x16x32_S32S8S8S32_TT) { - run_mma_test, - decltype(Shape<_8, _64, _32>{}), int8_t, int8_t, int32_t>( - Shape<_8, _64, _32>{}, Shape<_1, _1, _1>{}); + run_mma_test( + Shape<_128, _128, _16>{}, Shape<_1, _1, _1>{}); } TEST(PVC_CuTe_Xe, MMA_XE_8x16x32_S32U8U8S32_TT) { - run_mma_test, - decltype(Shape<_64, _64, _32>{}), uint8_t, uint8_t, int32_t>( - Shape<_64, _64, _32>{}, Shape<_2, _2, _1>{}); + run_mma_test( + Shape<_128, _128, _16>{}, Shape<_2, _2, _1>{}); } TEST(PVC_CuTe_Xe, MMA_XE_4x16x32_S32U8U8S32_TT) { - run_mma_test, - decltype(Shape<_32, _64, _32>{}), uint8_t, uint8_t, int32_t>( - Shape<_32, _64, _32>{}, Shape<_2, _2, _1>{}); + run_mma_test( + Shape<_128, _128, _16>{}, Shape<_2, _2, _1>{}); } TEST(PVC_CuTe_Xe, MMA_XE_2x16x32_S32U8U8S32_TT) { - run_mma_test, - decltype(Shape<_16, _32, _32>{}), uint8_t, uint8_t, int32_t>( - Shape<_16, _32, _32>{}, Shape<_4, _2, _1>{}); + run_mma_test( + Shape<_128, _128, _16>{}, Shape<_4, _2, _1>{}); } TEST(PVC_CuTe_Xe, MMA_XE_1x16x32_S32U8U8S32_TT) { - run_mma_test, - decltype(Shape<_8, _64, _32>{}), uint8_t, uint8_t, int32_t>( - Shape<_8, _64, _32>{}, Shape<_1, _1, _1>{}); + run_mma_test( + Shape<_128, _128, _16>{}, Shape<_1, _1, _1>{}); } TEST(PVC_CuTe_Xe, MMA_XE_8x16x16_F32BF16BF16F32_TT) { - run_mma_test, - decltype(Shape<_64, _64, _16>{}), + run_mma_test( - Shape<_64, _64, _16>{}, Shape<_2, _2, _1>{}); + Shape<_128, _128, _16>{}, Shape<_2, _2, _1>{}); } TEST(PVC_CuTe_Xe, MMA_XE_4x16x16_F32BF16BF16F32_TT) { - run_mma_test, - decltype(Shape<_32, _64, _16>{}), + run_mma_test( - Shape<_32, _64, _16>{}, Shape<_2, _2, _1>{}); + Shape<_128, _128, _16>{}, Shape<_2, _2, _1>{}); } TEST(PVC_CuTe_Xe, MMA_XE_2x16x16_F32BF16BF16F32_TT) { - run_mma_test, - decltype(Shape<_128, _128, _16>{}), + run_mma_test( Shape<_128, _128, _16>{}, Shape<_2, _4, _1>{}); } TEST(PVC_CuTe_Xe, MMA_XE_1x16x16_F32BF16BF16F32_TT) { - run_mma_test, - decltype(Shape<_8, _64, _16>{}), + run_mma_test( - Shape<_8, _64, _16>{}, Shape<_1, _1, _1>{}); + Shape<_128, _128, _16>{}, Shape<_1, _1, _1>{}); } TEST(PVC_CuTe_Xe, MMA_XE_8x16x16_F32F16F16F32_TT) { - run_mma_test, - decltype(Shape<_64, _64, _16>{}), + run_mma_test( - Shape<_64, _64, _16>{}, Shape<_2, _2, _1>{}); + Shape<_128, _128, _16>{}, Shape<_2, _2, _1>{}); } TEST(PVC_CuTe_Xe, MMA_XE_4x16x16_F32F16F16F32_TT) { - run_mma_test, - decltype(Shape<_32, _64, _16>{}), + run_mma_test( - Shape<_32, _64, _16>{}, Shape<_2, _2, _1>{}); + Shape<_128, _128, _16>{}, Shape<_2, _2, _1>{}); } TEST(PVC_CuTe_Xe, MMA_XE_2x16x16_F32F16F16F32_TT) { - run_mma_test, - decltype(Shape<_128, _128, _16>{}), + run_mma_test( Shape<_128, _128, _16>{}, Shape<_4, _2, _1>{}); } TEST(PVC_CuTe_Xe, MMA_XE_1x16x16_F32F16F16F32_TT) { - run_mma_test, - decltype(Shape<_128, _128, _16>{}), + run_mma_test( Shape<_128, _128, _16>{}, Shape<_1, _1, _1>{}); } TEST(PVC_CuTe_Xe, MMA_XE_8x16x8_F32TF32TF32F32_TT) { - run_mma_test, - decltype(Shape<_64, _64, _8>{}), + run_mma_test( - Shape<_64, _64, _8>{}, Shape<_2, _2, _1>{}); + Shape<_128, _128, _16>{}, Shape<_2, _2, _1>{}); } TEST(PVC_CuTe_Xe, MMA_XE_4x16x8_F32TF32TF32F32_TT) { - run_mma_test, - decltype(Shape<_32, _64, _8>{}), + run_mma_test( - Shape<_32, _64, _8>{}, Shape<_2, _2, _1>{}); + Shape<_128, _128, _16>{}, Shape<_2, _2, _1>{}); } TEST(PVC_CuTe_Xe, MMA_XE_2x16x8_F32TF32TF32F32_TT) { - run_mma_test, - decltype(Shape<_128, _128, _8>{}), + run_mma_test( - Shape<_128, _128, _8>{}, Shape<_4, _2, _1>{}); + Shape<_128, _128, _16>{}, Shape<_4, _2, _1>{}); } TEST(PVC_CuTe_Xe, MMA_XE_1x16x8_F32TF32TF32F32_TT) { - run_mma_test, - decltype(Shape<_8, _64, _8>{}), + run_mma_test( - Shape<_8, _64, _8>{}, Shape<_1, _1, _1>{}); + Shape<_128, _128, _16>{}, Shape<_1, _1, _1>{}); } TEST(PVC_CuTe_Xe, FMA_XE_UniversalFMA_F32F32F32F32) { - run_mma_test, Shape<_1, _1, _1>, - decltype(Shape<_128, _128, _8>{}), float, float, float>( - Shape<_128, _128, _8>{}, Shape<_1, _1, _1>{}); + run_mma_test, float, float, float>( + Shape<_128, _128, _16>{}, Shape<_1, _1, _1>{}); }