@@ -69,31 +69,31 @@ using DummyAmdgcnMma = amdgcn_mma<fp32_t,
6969
7070/* ! @struct MmaDefaultSelector
7171 * @brief For dummy Id only, instantiate tests for both MFMA and WMMA selectors so we can test them both
72- * @tparam ADataType Data type of matrix A
73- * @tparam BDataType Data type of matrix B
74- * @tparam CDataType Data type of the accumulator
75- * @tparam ChunkM Size of the M dimension of the chunk to decompose
76- * @tparam ChunkN Size of the N dimension of the chunk to decompose
77- * @tparam ChunkK Size of the K dimension of the chunk to decompose
72+ * @tparam ADataType Data type of matrix A
73+ * @tparam BDataType Data type of matrix B
74+ * @tparam CDataType Data type of the accumulator
75+ * @tparam WaveTileM Size of the M dimension of the WaveTile to decompose
76+ * @tparam WaveTileN Size of the N dimension of the WaveTile to decompose
77+ * @tparam WaveTileK Size of the K dimension of the WaveTile to decompose
7878 * @tparam CompilerTarget The compiler target
79- * @tparam OpFamily The MMA operation family
79+ * @tparam OpFamily The MMA operation family
8080 */
8181template <typename ADataType,
8282 typename BDataType,
8383 typename CDataType,
84- uint32_t ChunkM ,
85- uint32_t ChunkN ,
86- uint32_t ChunkK ,
84+ uint32_t WaveTileM ,
85+ uint32_t WaveTileN ,
86+ uint32_t WaveTileK ,
8787 typename CompilerTarget,
8888 MmaOpFamily OpFamily>
8989// TODO: c++20 amdgcn_target_arch_id CompilerTarget>
9090// TODO: requires
9191struct MmaDefaultSelector <ADataType,
9292 BDataType,
9393 CDataType,
94- ChunkM ,
95- ChunkN ,
96- ChunkK ,
94+ WaveTileM ,
95+ WaveTileN ,
96+ WaveTileK ,
9797 CompilerTarget,
9898 OpFamily,
9999 enable_if_all<enable_if_target_id_dummy_t <CompilerTarget>,
@@ -311,11 +311,11 @@ TEST(TestAmdgcnMma, MmaDefaultSelectorUnsupported)
311311 EXPECT_FALSE (MmaOpTraits<SelectedMma>::IsSupported);
312312}
313313
314- // Test MmaDefaultSelector for supported DummyAmdgcnMma on chunk sizes other than 16x16x16
315- // This tests that the selector can still pick the correct MMA op even if the chunk sizes differ
316- TEST (TestAmdgcnMma, MmaDefaultSelectorSupportedChunk )
314+ // Test MmaDefaultSelector for supported DummyAmdgcnMma on WaveTile sizes other than 16x16x16
315+ // This tests that the selector can still pick the correct MMA op even if the WaveTile sizes differ
316+ TEST (TestAmdgcnMma, MmaDefaultSelectorSupportedWaveTile )
317317{
318- // Select indirectly with a chunk size of 256x128x64
318+ // Select indirectly with a WaveTile size of 256x128x64
319319 using SelectedMma = MmaDefaultSelector<fp32_t ,
320320 fp32_t ,
321321 fp32_t ,
@@ -332,8 +332,8 @@ TEST(TestAmdgcnMma, MmaDefaultSelectorSupportedChunk)
332332 EXPECT_TRUE (MmaOpTraits<SelectedMma>::IsSupported);
333333}
334334
335- // Test MmaDefaultSelector for a different chunk size and supported arch
336- TEST (TestAmdgcnMma, MmaDefaultSelectorUnsupportedChunk )
335+ // Test MmaDefaultSelector for a different WaveTile size and supported arch
336+ TEST (TestAmdgcnMma, MmaDefaultSelectorUnsupportedWaveTile )
337337{
338338 // This should fall back to unsupported since DummyAmdgcnMma only supports 16x16x16
339339 using SelectedMma = MmaDefaultSelector<fp32_t ,
@@ -367,34 +367,34 @@ TEST(TestAmdgcnMma, MmaDefaultSelectorFp16Unsupported)
367367// Test on real hardware for MmaOp selection.
368368// This is not a GEMM kernel, but a simple test to ensure that the selected MmaOp works correctly on
369369// real hardware. Assumption: inputs are all 1's. The multiply-accumulate functionality can be tested
370- // here by looping over the k dimension and accumulating the results. They should be equal to ChunkK
371- // regardless of hardware.
370+ // here by looping over the k dimension and accumulating the results. They should be equal to
371+ // WaveTileK regardless of hardware.
372372template <typename ADataType,
373373 typename BDataType,
374374 typename CDataType,
375- uint32_t ChunkM ,
376- uint32_t ChunkN ,
377- uint32_t ChunkK >
375+ uint32_t WaveTileM ,
376+ uint32_t WaveTileN ,
377+ uint32_t WaveTileK >
378378__global__ void test_accum_over_k (void * a, void * b, void * c, void * out)
379379{
380380 using Selector = MmaDefaultSelector<ADataType,
381381 BDataType,
382382 CDataType,
383- ChunkM ,
384- ChunkN ,
385- ChunkK ,
383+ WaveTileM ,
384+ WaveTileN ,
385+ WaveTileK ,
386386 decltype (get_compiler_target ()),
387387 MmaOpFamily::DENSE>;
388388
389389 using MmaOp = typename Selector::SelectedOp;
390390 using CVecType = typename MmaOp::CVecType;
391391
392- static constexpr uint32_t kIters = ChunkK / MmaOp::kK ;
392+ static constexpr uint32_t kIters = WaveTileK / MmaOp::kK ;
393393
394394 // Initialize the accumulator
395395 CVecType result = *reinterpret_cast <typename MmaOp::CVecType*>(c);
396396
397- // Accumulate input AxB over ChunkK /FragK iterations
397+ // Accumulate input AxB over WaveTileK /FragK iterations
398398 for (uint32_t i = 0 ; i < kIters ; ++i)
399399 {
400400 result = MmaOp::exec (*reinterpret_cast <typename MmaOp::AVecType*>(a),
@@ -430,16 +430,16 @@ TEST(TestAmdgcnMma, MmaSelector_F16_F16_F32_16x16x32_Real)
430430 using BType = fp16_t ;
431431 using CType = fp32_t ;
432432
433- // Chunk size, also the expected fragment size from the selector.
433+ // WaveTile size, also the expected fragment size (MmaTile) from the selector.
434434 // Note: Actual FragK might be slightly different due to hardware implementation, but the
435435 // test_accum_over_k kernel will loop over the K dimension to ensure that the total K is
436436 // correct.
437- static constexpr uint32_t ChunkM = 16 ;
438- static constexpr uint32_t ChunkN = 16 ;
439- static constexpr uint32_t ChunkK = 32 ;
440- static constexpr uint32_t FragM = ChunkM ;
441- static constexpr uint32_t FragN = ChunkN ;
442- static constexpr uint32_t FragK = ChunkK ;
437+ static constexpr uint32_t WaveTileM = 16 ;
438+ static constexpr uint32_t WaveTileN = 16 ;
439+ static constexpr uint32_t WaveTileK = 32 ;
440+ static constexpr uint32_t FragM = WaveTileM ;
441+ static constexpr uint32_t FragN = WaveTileN ;
442+ static constexpr uint32_t FragK = WaveTileK ;
443443
444444 // Gfx11 has input data duplication and no accumulator padding (MultiplierC = 1)
445445 // TODO: c++20 use is_target_family_gfx11(currentArchId)
@@ -480,16 +480,16 @@ TEST(TestAmdgcnMma, MmaSelector_F16_F16_F32_16x16x32_Real)
480480 HIP_CHECK_ERROR (hipMemcpy (d_c, h_c.data (), CSize, hipMemcpyHostToDevice));
481481
482482 const auto wave_size = getDeviceWaveSize ();
483- test_accum_over_k<AType, BType, CType, ChunkM, ChunkN, ChunkK >
483+ test_accum_over_k<AType, BType, CType, WaveTileM, WaveTileN, WaveTileK >
484484 <<<1 , wave_size>>>(d_a, d_b, d_c, d_out);
485485 HIP_CHECK_ERROR (hipDeviceSynchronize ());
486486
487487 HIP_CHECK_ERROR (hipMemcpy (h_out.data (), d_out, CSize, hipMemcpyDeviceToHost));
488488
489- // Output should be ChunkK for all elements, because the inputs are all 1's
489+ // Output should be WaveTileK for all elements, because the inputs are all 1's
490490 for (size_t i = 0 ; i < CElements; ++i)
491491 {
492- CType expected = static_cast <CType>(ChunkK );
492+ CType expected = static_cast <CType>(WaveTileK );
493493
494494 EXPECT_NEAR (h_out[i], expected, 1e-3 );
495495 }
@@ -502,7 +502,7 @@ TEST(TestAmdgcnMma, MmaSelector_F16_F16_F32_16x16x32_Real)
502502
503503// Do a live test. At minimum, there should be a solution on real hardware for F16_F16_F32_16x16x32
504504// The selector should be able to pick the correct MmaOp as a multiple of 16x16x32, even if the
505- // chunk sizes are larger than 16x16x32. This tests that the selector can handle larger chunk
505+ // WaveTile sizes are larger than 16x16x32. This tests that the selector can handle larger WaveTile
506506// sizes and still select the correct MmaOp.
507507TEST (TestAmdgcnMma, MmaSelector_F16_F16_F32_112x112x128_Real)
508508{
@@ -528,13 +528,13 @@ TEST(TestAmdgcnMma, MmaSelector_F16_F16_F32_112x112x128_Real)
528528 using BType = fp16_t ;
529529 using CType = fp32_t ;
530530
531- // Chunk size to test for decomposition.
532- // We expect the selector to pick a 16x16 chunk
533- static constexpr uint32_t ChunkM = 112 ;
534- static constexpr uint32_t ChunkN = 112 ;
535- static constexpr uint32_t ChunkK = 128 ;
531+ // WaveTile size to test for decomposition.
532+ // We expect the selector to pick a 16x16 MmaTile
533+ static constexpr uint32_t WaveTileM = 112 ;
534+ static constexpr uint32_t WaveTileN = 112 ;
535+ static constexpr uint32_t WaveTileK = 128 ;
536536
537- // The expected fragment size from the selector (multiple of 16).
537+ // The expected fragment size from the selector (MmaTile, multiple of 16).
538538 // Note: Actual FragK might be slightly different due to hardware implementation, but the
539539 // test_accum_over_k kernel will loop over the K dimension to ensure that the total K is
540540 // correct.
@@ -581,16 +581,16 @@ TEST(TestAmdgcnMma, MmaSelector_F16_F16_F32_112x112x128_Real)
581581 HIP_CHECK_ERROR (hipMemcpy (d_c, h_c.data (), CSize, hipMemcpyHostToDevice));
582582
583583 const auto wave_size = getDeviceWaveSize ();
584- test_accum_over_k<AType, BType, CType, ChunkM, ChunkN, ChunkK >
584+ test_accum_over_k<AType, BType, CType, WaveTileM, WaveTileN, WaveTileK >
585585 <<<1 , wave_size>>>(d_a, d_b, d_c, d_out);
586586 HIP_CHECK_ERROR (hipDeviceSynchronize ());
587587
588588 HIP_CHECK_ERROR (hipMemcpy (h_out.data (), d_out, CSize, hipMemcpyDeviceToHost));
589589
590- // Output should be ChunkK for all elements, because the inputs are all 1's
590+ // Output should be WaveTileK for all elements, because the inputs are all 1's
591591 for (size_t i = 0 ; i < CElements; ++i)
592592 {
593- CType expected = static_cast <CType>(ChunkK );
593+ CType expected = static_cast <CType>(WaveTileK );
594594
595595 EXPECT_NEAR (h_out[i], expected, 1e-3 );
596596 }
0 commit comments