Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
303 changes: 90 additions & 213 deletions test/unit/cute/intel_xe/mma.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,284 +30,161 @@
*
**************************************************************************************************/

#include "cutlass/detail/layout.hpp"
#include "cutlass_unit_test.h"

#include <cute/tensor.hpp>
#include <sycl/sycl.hpp>
#include <cute/util/compat.hpp>

#include "cutlass_unit_test.h"
#include "utils.hpp"
#include "../cooperative_gemm_common.hpp"

using namespace cute;
using namespace cutlass;
using namespace compat::experimental;

#define SUBGROUP_SIZE (16)

template<class...> class GemmDeviceName;

template <class MMA, uint32_t wg_tile_m, uint32_t wg_tile_n, uint32_t sg_tile_m,
uint32_t sg_tile_n, uint32_t sg_tile_k, class TA, class TB, class TC>
void gemm_device(TA const *A, TB const *B, TC *C, uint32_t m, uint32_t n,
uint32_t k) {
using namespace cute;

// Represent the full tensors
Tensor mA = make_tensor(make_gmem_ptr(A),
make_layout(make_shape(m, k), make_stride(k, 1)));
Tensor mB = make_tensor(make_gmem_ptr(B),
make_layout(make_shape(n, k), make_stride(1, n)));
Tensor mC = make_tensor(make_gmem_ptr(C),
make_layout(make_shape(m, n), make_stride(n, 1)));

// Get the appropriate blocks for this thread block
auto cta_coord = make_coord(BlockIdxX(), BlockIdxY(), _); // (m,n,k)

auto cta_tiler =
make_shape(Int<wg_tile_m>{}, Int<wg_tile_n>{}, Int<sg_tile_k>{});
Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X, _1>{});
Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step<X, _1, _1>{});
Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1, _1, X>{});

TiledMMA mma = make_tiled_mma(
MMA_Atom<MMA>{},
Layout< // Require: subgroup_layout
Shape<Int<cute::ceil_div(wg_tile_m, sg_tile_m)>,
Int<cute::ceil_div(wg_tile_n, sg_tile_n)>, _1>>{});

ThrMMA thrd_mma = mma.get_slice(ThreadIdxX());

Tensor tgA = thrd_mma.partition_A(gA);
Tensor fragment_A =
thrd_mma.make_fragment_A(tgA(_, _, _, 0)); // (MMA, MMA_M, MMA_K)

Tensor tgB = thrd_mma.partition_B(gB);
Tensor fragment_B =
thrd_mma.make_fragment_B(tgB(_, _, _, 0)); // (MMA, MMA_N, MMA_K)

Tensor tgC = thrd_mma.partition_C(gC);
Tensor fragment_C = thrd_mma.make_fragment_C(tgC); // (MMA, MMA_M, MMA_N)
clear(fragment_C);

#define CUTLASS_ENABLE_DEBUG_PRINTS (0)

#undef LOG_THREAD
#define LOG_THREAD (16)

#if CUTLASS_ENABLE_DEBUG_PRINTS
if (thread(LOG_THREAD)) {
print("===================== A :\n");

print(" mA : ");
print(mA);
print("\n");
print(" gA : ");
print(gA);
print("\n");
print("tgA : ");
print(tgA);
print("\n");
print("fragment_A : ");
print(fragment_A);
print("\n\n");
}
#endif

#if CUTLASS_ENABLE_DEBUG_PRINTS
if (thread(LOG_THREAD)) {
print("===================== B :\n");

print(" mB : ");
print(mB);
print("\n");
print(" gB : ");
print(gB);
print("\n");
print("tgB : ");
print(tgB);
print("\n");
print("fragment_B : ");
print(fragment_B);
print("\n\n");
}
#endif

#if CUTLASS_ENABLE_DEBUG_PRINTS
if (thread(LOG_THREAD)) {
print("===================== C :\n");

print(" mC : ");
print(mC);
print("\n");
print(" gC : ");
print(gC);
print("\n");
print("tgC : ");
print(tgC);
print("\n");
print("fragment_C : ");
print(fragment_C);
print("\n\n");
}
#endif

auto k_tile_max = size<3>(tgA);
for (int k_tile = 0; k_tile < k_tile_max; ++k_tile) {
auto kA = tgA(_, _, _, k_tile);
auto kB = tgB(_, _, _, k_tile);
// Copy gmem to rmem for k_tile+1 with tA|tB thread-partitioned tensors
copy(kA, fragment_A);
copy(kB, fragment_B);

// Compute gemm on mma-partitioned smem
cute::gemm(mma, fragment_A, fragment_B, fragment_C);
}

copy(fragment_C, tgC);
namespace {
constexpr uint32_t thread_block_size = 128;
constexpr uint32_t max_vec_bits = 128;
}

// Setup params for a NT GEMM
template <class MMA, uint32_t wg_tile_m, uint32_t wg_tile_n, uint32_t sg_tile_m,
uint32_t sg_tile_n, uint32_t sg_tile_k, class TA, class TB, class TC>
void gemm(int m, int n, int k, TA *A, TB *B, TC *C) {
using namespace cute;

auto dimBlock = compat::dim3(SUBGROUP_SIZE * (wg_tile_m * wg_tile_n) /
(sg_tile_m * sg_tile_n));
auto dimGrid = compat::dim3(size(ceil_div(m, wg_tile_m)),
size(ceil_div(n, wg_tile_n)));

launch<gemm_device<MMA, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k,
TA, TB, TC>, GemmDeviceName<MMA, TA, TB, TC>>(
launch_policy{dimGrid, dimBlock,
kernel_properties{sycl_exp::sub_group_size<SUBGROUP_SIZE>}},
A, B, C, m, n, k);
}

template <class MMA, uint32_t wg_tile_m, uint32_t wg_tile_n, uint32_t sg_tile_m,
uint32_t sg_tile_n, uint32_t sg_tile_k, class TA, class TB, class TC>
void MMA_Test(int m, int n, int k) {
cutlass::host_vector<TA> h_A(m * k);
cutlass::host_vector<TB> h_B(n * k);
cutlass::host_vector<TC> h_C(m * n);

fill_matrix(h_A);
fill_matrix(h_B);

cutlass::device_vector<TA> d_A = h_A;
cutlass::device_vector<TB> d_B = h_B;
cutlass::device_vector<TC> d_C = h_C;

::gemm<MMA, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k>(
m, n, k, d_A.data(), d_B.data(), d_C.data());
compat::wait();

h_C = d_C;
verify(m, n, k, h_A.data(), h_B.data(), h_C.data());
template<typename MMAAtom, typename LayoutShape, typename ShapeMNK,
typename TA, typename TB, typename TC>
void run_mma_test(ShapeMNK shape_mnk, LayoutShape layout_shape) {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you should pass ShapeMNK and LayoutShape only either as argument or template argument not both. If you choose the first place the template arguments at the end so that they can be deduced. If you choose the 2nd remove the arguments.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

auto tiled_mma = TiledMMA<MMA_Atom<MMAAtom>, Layout<LayoutShape>>{};
test_cooperative_gemm_col_major_layout<thread_block_size, max_vec_bits, TA, TB, TC>(
shape_mnk, tiled_mma);
}

TEST(PVC_CuTe_Xe, MMA_XE_8x16x32_S32S8S8S32_TT) {
MMA_Test<XE_8x16x32_S32S8S8S32_TT, 64, 64, 8, 16, 32, int8_t, int8_t,
int32_t>(512, 512, 256);
run_mma_test<XE_8x16x32_S32S8S8S32_TT, Shape<_2, _2, _1>,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the intention that the two of those are the same? If not, why not?
If you tried to keep them it's wrong because e.g. MNK would need to be 512,512,256 not 64,64,32.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The same shape, such as <512 512 256>, cannot be used here because it will throw an OUT OF RESOURCE error during compilation. Since this unit test only verifies the correctness of atom MMA, the shape size has been reduced.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where is the error coming from?

In general please document in the PR message any non-obvious changes unrelated to the main goal of the PR. Helps both with code review and with later understanding why code was changed.

decltype(Shape<_64, _64, _32>{}), int8_t, int8_t, int32_t>(

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

Suggested change
decltype(Shape<_64, _64, _32>{}), int8_t, int8_t, int32_t>(
Shape<_64, _64, _32>, int8_t, int8_t, int32_t>(

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

Shape<_64, _64, _32>{}, Shape<_2, _2, _1>{});
}

TEST(PVC_CuTe_Xe, MMA_XE_4x16x32_S32S8S8S32_TT) {
MMA_Test<XE_4x16x32_S32S8S8S32_TT, 32, 64, 4, 16, 32, int8_t, int8_t,
int32_t>(512, 512, 256);
run_mma_test<XE_4x16x32_S32S8S8S32_TT, Shape<_2, _2, _1>,
decltype(Shape<_32, _64, _32>{}), int8_t, int8_t, int32_t>(
Shape<_32, _64, _32>{}, Shape<_2, _2, _1>{});
}

TEST(PVC_CuTe_Xe, MMA_XE_2x16x32_S32S8S8S32_TT) {
MMA_Test<XE_2x16x32_S32S8S8S32_TT, 16, 64, 2, 16, 32, int8_t, int8_t,
int32_t>(512, 512, 256);
run_mma_test<XE_2x16x32_S32S8S8S32_TT, Shape<_4, _2, _1>,
decltype(Shape<_16, _32, _32>{}), int8_t, int8_t, int32_t>(
Shape<_16, _32, _32>{}, Shape<_4, _2, _1>{});
}

TEST(PVC_CuTe_Xe, MMA_XE_1x16x32_S32S8S8S32_TT) {
MMA_Test<XE_1x16x32_S32S8S8S32_TT, 8, 64, 1, 16, 32, int8_t, int8_t, int32_t>(
512, 512, 256);
run_mma_test<XE_1x16x32_S32S8S8S32_TT, Shape<_1, _1, _1>,
decltype(Shape<_8, _64, _32>{}), int8_t, int8_t, int32_t>(
Shape<_8, _64, _32>{}, Shape<_1, _1, _1>{});
}

TEST(PVC_CuTe_Xe, MMA_XE_8x16x32_S32U8U8S32_TT) {
MMA_Test<XE_8x16x32_S32U8U8S32_TT, 64, 64, 8, 16, 32, uint8_t, uint8_t,
int32_t>(512, 512, 256);
run_mma_test<XE_8x16x32_S32U8U8S32_TT, Shape<_2, _2, _1>,
decltype(Shape<_64, _64, _32>{}), uint8_t, uint8_t, int32_t>(
Shape<_64, _64, _32>{}, Shape<_2, _2, _1>{});
}

TEST(PVC_CuTe_Xe, MMA_XE_4x16x32_S32U8U8S32_TT) {
MMA_Test<XE_4x16x32_S32U8U8S32_TT, 32, 64, 4, 16, 32, uint8_t, uint8_t,
int32_t>(512, 512, 256);
run_mma_test<XE_4x16x32_S32U8U8S32_TT, Shape<_2, _2, _1>,
decltype(Shape<_32, _64, _32>{}), uint8_t, uint8_t, int32_t>(
Shape<_32, _64, _32>{}, Shape<_2, _2, _1>{});
}

TEST(PVC_CuTe_Xe, MMA_XE_2x16x32_S32U8U8S32_TT) {
MMA_Test<XE_2x16x32_S32U8U8S32_TT, 16, 64, 2, 16, 32, uint8_t, uint8_t,
int32_t>(512, 512, 256);
run_mma_test<XE_2x16x32_S32U8U8S32_TT, Shape<_4, _2, _1>,
decltype(Shape<_16, _32, _32>{}), uint8_t, uint8_t, int32_t>(
Shape<_16, _32, _32>{}, Shape<_4, _2, _1>{});
}

TEST(PVC_CuTe_Xe, MMA_XE_1x16x32_S32U8U8S32_TT) {
MMA_Test<XE_1x16x32_S32U8U8S32_TT, 8, 64, 1, 16, 32, uint8_t, uint8_t,
int32_t>(512, 512, 256);
run_mma_test<XE_1x16x32_S32U8U8S32_TT, Shape<_1, _1, _1>,
decltype(Shape<_8, _64, _32>{}), uint8_t, uint8_t, int32_t>(
Shape<_8, _64, _32>{}, Shape<_1, _1, _1>{});
}

TEST(PVC_CuTe_Xe, MMA_XE_8x16x16_F32BF16BF16F32_TT) {
MMA_Test<XE_8x16x16_F32BF16BF16F32_TT, 256, 256, 32, 64, 32, bfloat16_t,
bfloat16_t, float>(512, 512, 256);
run_mma_test<XE_8x16x16_F32BF16BF16F32_TT, Shape<_2, _2, _1>,
decltype(Shape<_64, _64, _16>{}),
cutlass::bfloat16_t, cutlass::bfloat16_t, float>(
Shape<_64, _64, _16>{}, Shape<_2, _2, _1>{});
}

TEST(PVC_CuTe_Xe, MMA_XE_4x16x16_F32BF16BF16F32_TT) {
MMA_Test<XE_4x16x16_F32BF16BF16F32_TT, 32, 64, 4, 16, 16, bfloat16_t,
bfloat16_t, float>(512, 512, 256);
run_mma_test<XE_4x16x16_F32BF16BF16F32_TT, Shape<_2, _2, _1>,
decltype(Shape<_32, _64, _16>{}),
cutlass::bfloat16_t, cutlass::bfloat16_t, float>(
Shape<_32, _64, _16>{}, Shape<_2, _2, _1>{});
}

TEST(PVC_CuTe_Xe, MMA_XE_2x16x16_F32BF16BF16F32_TT) {
MMA_Test<XE_2x16x16_F32BF16BF16F32_TT, 16, 64, 2, 16, 16, bfloat16_t,
bfloat16_t, float>(512, 512, 256);
run_mma_test<XE_2x16x16_F32BF16BF16F32_TT, Shape<_2, _4, _1>,
decltype(Shape<_128, _128, _16>{}),
cutlass::bfloat16_t, cutlass::bfloat16_t, float>(
Shape<_128, _128, _16>{}, Shape<_2, _4, _1>{});
}

TEST(PVC_CuTe_Xe, MMA_XE_1x16x16_F32BF16BF16F32_TT) {
MMA_Test<XE_1x16x16_F32BF16BF16F32_TT, 8, 64, 1, 16, 16, bfloat16_t,
bfloat16_t, float>(512, 512, 256);
run_mma_test<XE_1x16x16_F32BF16BF16F32_TT, Shape<_1, _1, _1>,
decltype(Shape<_8, _64, _16>{}),
cutlass::bfloat16_t, cutlass::bfloat16_t, float>(
Shape<_8, _64, _16>{}, Shape<_1, _1, _1>{});
}

TEST(PVC_CuTe_Xe, MMA_XE_8x16x16_F32F16F16F32_TT) {
MMA_Test<XE_8x16x16_F32F16F16F32_TT, 64, 64, 8, 16, 16, half_t, half_t,
float>(512, 512, 256);
run_mma_test<XE_8x16x16_F32F16F16F32_TT, Shape<_2, _2, _1>,
decltype(Shape<_64, _64, _16>{}),
cutlass::half_t, cutlass::half_t, float>(
Shape<_64, _64, _16>{}, Shape<_2, _2, _1>{});
}

TEST(PVC_CuTe_Xe, MMA_XE_4x16x16_F32F16F16F32_TT) {
MMA_Test<XE_4x16x16_F32F16F16F32_TT, 32, 64, 4, 16, 16, half_t, half_t,
float>(512, 512, 256);
run_mma_test<XE_4x16x16_F32F16F16F32_TT, Shape<_2, _2, _1>,
decltype(Shape<_32, _64, _16>{}),
cutlass::half_t, cutlass::half_t, float>(
Shape<_32, _64, _16>{}, Shape<_2, _2, _1>{});
}

TEST(PVC_CuTe_Xe, MMA_XE_2x16x16_F32F16F16F32_TT) {
MMA_Test<XE_2x16x16_F32F16F16F32_TT, 16, 64, 2, 16, 16, half_t, half_t,
float>(512, 512, 256);
run_mma_test<XE_2x16x16_F32F16F16F32_TT, Shape<_4, _2, _1>,
decltype(Shape<_128, _128, _16>{}),
cutlass::half_t, cutlass::half_t, float>(
Shape<_128, _128, _16>{}, Shape<_4, _2, _1>{});
}

TEST(PVC_CuTe_Xe, MMA_XE_1x16x16_F32F16F16F32_TT) {
MMA_Test<XE_1x16x16_F32F16F16F32_TT, 8, 64, 1, 16, 16, half_t, half_t, float>(
512, 512, 256);
run_mma_test<XE_1x16x16_F32F16F16F32_TT, Shape<_1, _1, _1>,
decltype(Shape<_128, _128, _16>{}),
cutlass::half_t, cutlass::half_t, float>(
Shape<_128, _128, _16>{}, Shape<_1, _1, _1>{});
}

TEST(PVC_CuTe_Xe, FMA_XE_UniversalFMA_F32F32F32F32) {
MMA_Test<UniversalFMA<float, float, float, float>, 64, 64, 8, 16, 16, float,
float, float>(512, 512, 256);
TEST(PVC_CuTe_Xe, MMA_XE_8x16x8_F32TF32TF32F32_TT) {
run_mma_test<XE_8x16x8_F32TF32TF32F32_TT, Shape<_2, _2, _1>,
decltype(Shape<_64, _64, _8>{}),
cutlass::tfloat32_t, cutlass::tfloat32_t, float>(
Shape<_64, _64, _8>{}, Shape<_2, _2, _1>{});
}

TEST(PVC_CuTe_Xe, MMA_XE_1x16x8_F32TF32TF32F32_TT) {
MMA_Test<XE_1x16x8_F32TF32TF32F32_TT, 64, 64, 8, 16, 16, tfloat32_t,
tfloat32_t, float>(512, 512, 256);
TEST(PVC_CuTe_Xe, MMA_XE_4x16x8_F32TF32TF32F32_TT) {
run_mma_test<XE_4x16x8_F32TF32TF32F32_TT, Shape<_2, _2, _1>,
decltype(Shape<_32, _64, _8>{}),
cutlass::tfloat32_t, cutlass::tfloat32_t, float>(
Shape<_32, _64, _8>{}, Shape<_2, _2, _1>{});
}

TEST(PVC_CuTe_Xe, MMA_XE_2x16x8_F32TF32TF32F32_TT) {
MMA_Test<XE_2x16x8_F32TF32TF32F32_TT, 64, 64, 8, 16, 16, tfloat32_t,
tfloat32_t, float>(512, 512, 256);
run_mma_test<XE_2x16x8_F32TF32TF32F32_TT, Shape<_4, _2, _1>,
decltype(Shape<_128, _128, _8>{}),
cutlass::tfloat32_t, cutlass::tfloat32_t, float>(
Shape<_128, _128, _8>{}, Shape<_4, _2, _1>{});
}

TEST(PVC_CuTe_Xe, MMA_XE_4x16x8_F32TF32TF32F32_TT) {
MMA_Test<XE_4x16x8_F32TF32TF32F32_TT, 64, 64, 8, 16, 16, tfloat32_t,
tfloat32_t, float>(512, 512, 256);
TEST(PVC_CuTe_Xe, MMA_XE_1x16x8_F32TF32TF32F32_TT) {
run_mma_test<XE_1x16x8_F32TF32TF32F32_TT, Shape<_1, _1, _1>,
decltype(Shape<_8, _64, _8>{}),
cutlass::tfloat32_t, cutlass::tfloat32_t, float>(
Shape<_8, _64, _8>{}, Shape<_1, _1, _1>{});
}

TEST(PVC_CuTe_Xe, MMA_XE_8x16x8_F32TF32TF32F32_TT) {
MMA_Test<XE_8x16x8_F32TF32TF32F32_TT, 64, 64, 8, 16, 32, tfloat32_t,
tfloat32_t, float>(512, 512, 256);
TEST(PVC_CuTe_Xe, FMA_XE_UniversalFMA_F32F32F32F32) {
run_mma_test<UniversalFMA<float, float, float, float>, Shape<_1, _1, _1>,
decltype(Shape<_128, _128, _8>{}), float, float, float>(
Shape<_128, _128, _8>{}, Shape<_1, _1, _1>{});
}
Loading