Rewrite mma unit tests #557
@@ -30,284 +30,161 @@
 *
 **************************************************************************************************/

#include "cutlass/detail/layout.hpp"
#include "cutlass_unit_test.h"

#include <cute/tensor.hpp>
#include <sycl/sycl.hpp>
#include <cute/util/compat.hpp>

#include "utils.hpp"
#include "../cooperative_gemm_common.hpp"

using namespace cute;
using namespace cutlass;
using namespace compat::experimental;

#define SUBGROUP_SIZE (16)

template <class...> class GemmDeviceName;
template <class MMA, uint32_t wg_tile_m, uint32_t wg_tile_n, uint32_t sg_tile_m,
          uint32_t sg_tile_n, uint32_t sg_tile_k, class TA, class TB, class TC>
void gemm_device(TA const *A, TB const *B, TC *C, uint32_t m, uint32_t n,
                 uint32_t k) {
  using namespace cute;

  // Represent the full tensors
  Tensor mA = make_tensor(make_gmem_ptr(A),
                          make_layout(make_shape(m, k), make_stride(k, 1)));
  Tensor mB = make_tensor(make_gmem_ptr(B),
                          make_layout(make_shape(n, k), make_stride(1, n)));
  Tensor mC = make_tensor(make_gmem_ptr(C),
                          make_layout(make_shape(m, n), make_stride(n, 1)));

  // Get the appropriate blocks for this thread block
  auto cta_coord = make_coord(BlockIdxX(), BlockIdxY(), _); // (m,n,k)

  auto cta_tiler =
      make_shape(Int<wg_tile_m>{}, Int<wg_tile_n>{}, Int<sg_tile_k>{});
  Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X, _1>{});
  Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step<X, _1, _1>{});
  Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1, _1, X>{});

  TiledMMA mma = make_tiled_mma(
      MMA_Atom<MMA>{},
      Layout< // Require: subgroup_layout
          Shape<Int<cute::ceil_div(wg_tile_m, sg_tile_m)>,
                Int<cute::ceil_div(wg_tile_n, sg_tile_n)>, _1>>{});
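  // For the tile sizes exercised by the test below (wg_tile 64x64, sg_tile
  // 8x16) this subgroup layout evaluates to Shape<_8, _4, _1>: an 8x4 grid
  // of 32 subgroups covering the workgroup tile.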

  ThrMMA thrd_mma = mma.get_slice(ThreadIdxX());

  Tensor tgA = thrd_mma.partition_A(gA);
  Tensor fragment_A =
      thrd_mma.make_fragment_A(tgA(_, _, _, 0)); // (MMA, MMA_M, MMA_K)

  Tensor tgB = thrd_mma.partition_B(gB);
  Tensor fragment_B =
      thrd_mma.make_fragment_B(tgB(_, _, _, 0)); // (MMA, MMA_N, MMA_K)

  Tensor tgC = thrd_mma.partition_C(gC);
  Tensor fragment_C = thrd_mma.make_fragment_C(tgC); // (MMA, MMA_M, MMA_N)
  clear(fragment_C);

#define CUTLASS_ENABLE_DEBUG_PRINTS (0)

#undef LOG_THREAD
#define LOG_THREAD (16)

#if CUTLASS_ENABLE_DEBUG_PRINTS
  if (thread(LOG_THREAD)) {
    print("===================== A :\n");
    print("  mA : "); print(mA); print("\n");
    print("  gA : "); print(gA); print("\n");
    print(" tgA : "); print(tgA); print("\n");
    print("fragment_A : "); print(fragment_A); print("\n\n");
  }

  if (thread(LOG_THREAD)) {
    print("===================== B :\n");
    print("  mB : "); print(mB); print("\n");
    print("  gB : "); print(gB); print("\n");
    print(" tgB : "); print(tgB); print("\n");
    print("fragment_B : "); print(fragment_B); print("\n\n");
  }

  if (thread(LOG_THREAD)) {
    print("===================== C :\n");
    print("  mC : "); print(mC); print("\n");
    print("  gC : "); print(gC); print("\n");
    print(" tgC : "); print(tgC); print("\n");
    print("fragment_C : "); print(fragment_C); print("\n\n");
  }
#endif

  // Main loop over k tiles: load A/B fragments, then accumulate into C
  auto k_tile_max = size<3>(tgA);
  for (int k_tile = 0; k_tile < k_tile_max; ++k_tile) {
    auto kA = tgA(_, _, _, k_tile);
    auto kB = tgB(_, _, _, k_tile);
    // Copy gmem to rmem for this k_tile with the thread-partitioned tensors
    copy(kA, fragment_A);
    copy(kB, fragment_B);

    // Compute gemm on the mma-partitioned register fragments
    cute::gemm(mma, fragment_A, fragment_B, fragment_C);
  }

  copy(fragment_C, tgC);
}

namespace {
constexpr uint32_t thread_block_size = 128;
constexpr uint32_t max_vec_bits = 128;
} // namespace
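// thread_block_size: 128 work-items = 8 subgroups of SUBGROUP_SIZE (16).
// max_vec_bits reads as an upper bound on the vectorized access width used by
// the cooperative-gemm test harness (an assumption based on its name).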

// Set up params for an NT GEMM
template <class MMA, uint32_t wg_tile_m, uint32_t wg_tile_n, uint32_t sg_tile_m,
          uint32_t sg_tile_n, uint32_t sg_tile_k, class TA, class TB, class TC>
void gemm(int m, int n, int k, TA *A, TB *B, TC *C) {
  using namespace cute;

  // One subgroup of SUBGROUP_SIZE work-items per (sg_tile_m x sg_tile_n) tile
  auto dimBlock = compat::dim3(SUBGROUP_SIZE * (wg_tile_m * wg_tile_n) /
                               (sg_tile_m * sg_tile_n));
  auto dimGrid = compat::dim3(size(ceil_div(m, wg_tile_m)),
                              size(ceil_div(n, wg_tile_n)));

  launch<gemm_device<MMA, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k,
                     TA, TB, TC>,
         GemmDeviceName<MMA, TA, TB, TC>>(
      launch_policy{dimGrid, dimBlock,
                    kernel_properties{sycl_exp::sub_group_size<SUBGROUP_SIZE>}},
      A, B, C, m, n, k);
}
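
A quick sanity check of this launch configuration for the sizes used in the
test below (wg_tile 64x64, sg_tile 8x16, m = n = 512); the numbers are derived
from the formulas above, not taken from the diff itself:

// Each workgroup runs 64*64 / (8*16) = 32 subgroups of 16 work-items each.
static_assert(SUBGROUP_SIZE * (64 * 64) / (8 * 16) == 512,
              "dimBlock: 512 work-items per workgroup");
// The grid tiles the 512x512 output with 64x64 workgroup tiles.
static_assert((512 / 64) * (512 / 64) == 64, "dimGrid: 8 x 8 workgroups");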

template <class MMA, uint32_t wg_tile_m, uint32_t wg_tile_n, uint32_t sg_tile_m,
          uint32_t sg_tile_n, uint32_t sg_tile_k, class TA, class TB, class TC>
void MMA_Test(int m, int n, int k) {
  cutlass::host_vector<TA> h_A(m * k);
  cutlass::host_vector<TB> h_B(n * k);
  cutlass::host_vector<TC> h_C(m * n);

  fill_matrix(h_A);
  fill_matrix(h_B);

  cutlass::device_vector<TA> d_A = h_A;
  cutlass::device_vector<TB> d_B = h_B;
  cutlass::device_vector<TC> d_C = h_C;

  ::gemm<MMA, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k>(
      m, n, k, d_A.data(), d_B.data(), d_C.data());
  compat::wait();

  h_C = d_C;
  verify(m, n, k, h_A.data(), h_B.data(), h_C.data());
}

template <typename MMAAtom, typename LayoutShape, typename ShapeMNK,
          typename TA, typename TB, typename TC>
void run_mma_test(ShapeMNK shape_mnk, LayoutShape layout_shape) {
  auto tiled_mma = TiledMMA<MMA_Atom<MMAAtom>, Layout<LayoutShape>>{};
  test_cooperative_gemm_col_major_layout<thread_block_size, max_vec_bits,
                                         TA, TB, TC>(shape_mnk, tiled_mma);
}

TEST(PVC_CuTe_Xe, MMA_XE_8x16x32_S32S8S8S32_TT) {
  MMA_Test<XE_8x16x32_S32S8S8S32_TT, 64, 64, 8, 16, 32, int8_t, int8_t,
           int32_t>(512, 512, 256);
  run_mma_test<XE_8x16x32_S32S8S8S32_TT, Shape<_2, _2, _1>,
               decltype(Shape<_64, _64, _32>{}), int8_t, int8_t, int32_t>(
      Shape<_64, _64, _32>{}, Shape<_2, _2, _1>{});
}

Suggested change:
-               decltype(Shape<_64, _64, _32>{}), int8_t, int8_t, int32_t>(
+               Shape<_64, _64, _32>, int8_t, int8_t, int32_t>(
Fixed

You should pass ShapeMNK and LayoutShape either as function arguments or as
template arguments, not both. If you choose the former, place those template
parameters at the end so that they can be deduced; if you choose the latter,
remove the function arguments.

Fixed
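
A minimal sketch of the reviewer's two options, reusing the names from this PR
(the bodies are assumed to stay as in run_mma_test above; the fixed commit
itself is not shown here):

// Option 1: shapes as function arguments. The deducible template parameters
// go last, so callers spell out only the atom and the element types.
template <typename MMAAtom, typename TA, typename TB, typename TC,
          typename LayoutShape, typename ShapeMNK>
void run_mma_test(ShapeMNK shape_mnk, LayoutShape layout_shape) {
  auto tiled_mma = TiledMMA<MMA_Atom<MMAAtom>, Layout<LayoutShape>>{};
  test_cooperative_gemm_col_major_layout<thread_block_size, max_vec_bits,
                                         TA, TB, TC>(shape_mnk, tiled_mma);
}

// Usage: run_mma_test<XE_8x16x32_S32S8S8S32_TT, int8_t, int8_t, int32_t>(
//            Shape<_64, _64, _32>{}, Shape<_2, _2, _1>{});

// Option 2: shapes as template arguments only. Both shape types are empty,
// default-constructible static types, so no function arguments are needed.
template <typename MMAAtom, typename LayoutShape, typename ShapeMNK,
          typename TA, typename TB, typename TC>
void run_mma_test() {
  auto tiled_mma = TiledMMA<MMA_Atom<MMAAtom>, Layout<LayoutShape>>{};
  test_cooperative_gemm_col_major_layout<thread_block_size, max_vec_bits,
                                         TA, TB, TC>(ShapeMNK{}, tiled_mma);
}

// Usage: run_mma_test<XE_8x16x32_S32S8S8S32_TT, Shape<_2, _2, _1>,
//            Shape<_64, _64, _32>, int8_t, int8_t, int32_t>();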