Skip to content

Commit 9f4f0a7

Browse files
committed
Change some remaining references to Chunk to WaveTile
1 parent 2dbdd57 commit 9f4f0a7

File tree

2 files changed

+64
-64
lines changed

2 files changed

+64
-64
lines changed

projects/composablekernel/test/ck_tile/core/arch/mma/test_amdgcn_mma.cpp

Lines changed: 49 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -69,31 +69,31 @@ using DummyAmdgcnMma = amdgcn_mma<fp32_t,
6969

7070
/*! @struct MmaDefaultSelector
7171
* @brief For dummy Id only, instantiate tests for both MFMA and WMMA selectors so we can test them both
72-
* @tparam ADataType Data type of matrix A
73-
* @tparam BDataType Data type of matrix B
74-
* @tparam CDataType Data type of the accumulator
75-
* @tparam ChunkM Size of the M dimension of the chunk to decompose
76-
* @tparam ChunkN Size of the N dimension of the chunk to decompose
77-
* @tparam ChunkK Size of the K dimension of the chunk to decompose
72+
* @tparam ADataType Data type of matrix A
73+
* @tparam BDataType Data type of matrix B
74+
* @tparam CDataType Data type of the accumulator
75+
* @tparam WaveTileM Size of the M dimension of the WaveTile to decompose
76+
* @tparam WaveTileN Size of the N dimension of the WaveTile to decompose
77+
* @tparam WaveTileK Size of the K dimension of the WaveTile to decompose
7878
* @tparam CompilerTarget The compiler target
79-
* @tparam OpFamily The MMA operation family
79+
* @tparam OpFamily The MMA operation family
8080
*/
8181
template <typename ADataType,
8282
typename BDataType,
8383
typename CDataType,
84-
uint32_t ChunkM,
85-
uint32_t ChunkN,
86-
uint32_t ChunkK,
84+
uint32_t WaveTileM,
85+
uint32_t WaveTileN,
86+
uint32_t WaveTileK,
8787
typename CompilerTarget,
8888
MmaOpFamily OpFamily>
8989
// TODO: c++20 amdgcn_target_arch_id CompilerTarget>
9090
// TODO: requires
9191
struct MmaDefaultSelector<ADataType,
9292
BDataType,
9393
CDataType,
94-
ChunkM,
95-
ChunkN,
96-
ChunkK,
94+
WaveTileM,
95+
WaveTileN,
96+
WaveTileK,
9797
CompilerTarget,
9898
OpFamily,
9999
enable_if_all<enable_if_target_id_dummy_t<CompilerTarget>,
@@ -311,11 +311,11 @@ TEST(TestAmdgcnMma, MmaDefaultSelectorUnsupported)
311311
EXPECT_FALSE(MmaOpTraits<SelectedMma>::IsSupported);
312312
}
313313

314-
// Test MmaDefaultSelector for supported DummyAmdgcnMma on chunk sizes other than 16x16x16
315-
// This tests that the selector can still pick the correct MMA op even if the chunk sizes differ
316-
TEST(TestAmdgcnMma, MmaDefaultSelectorSupportedChunk)
314+
// Test MmaDefaultSelector for supported DummyAmdgcnMma on WaveTile sizes other than 16x16x16
315+
// This tests that the selector can still pick the correct MMA op even if the WaveTile sizes differ
316+
TEST(TestAmdgcnMma, MmaDefaultSelectorSupportedWaveTile)
317317
{
318-
// Select indirectly with a chunk size of 256x128x64
318+
// Select indirectly with a WaveTile size of 256x128x64
319319
using SelectedMma = MmaDefaultSelector<fp32_t,
320320
fp32_t,
321321
fp32_t,
@@ -332,8 +332,8 @@ TEST(TestAmdgcnMma, MmaDefaultSelectorSupportedChunk)
332332
EXPECT_TRUE(MmaOpTraits<SelectedMma>::IsSupported);
333333
}
334334

335-
// Test MmaDefaultSelector for a different chunk size and supported arch
336-
TEST(TestAmdgcnMma, MmaDefaultSelectorUnsupportedChunk)
335+
// Test MmaDefaultSelector for a different WaveTile size and supported arch
336+
TEST(TestAmdgcnMma, MmaDefaultSelectorUnsupportedWaveTile)
337337
{
338338
// This should fall back to unsupported since DummyAmdgcnMma only supports 16x16x16
339339
using SelectedMma = MmaDefaultSelector<fp32_t,
@@ -367,34 +367,34 @@ TEST(TestAmdgcnMma, MmaDefaultSelectorFp16Unsupported)
367367
// Test on real hardware for MmaOp selection.
368368
// This is not a GEMM kernel, but a simple test to ensure that the selected MmaOp works correctly on
369369
// real hardware. Assumption: inputs are all 1's. The multiply-accumulate functionality can be tested
370-
// here by looping over the k dimension and accumulating the results. They should be equal to ChunkK
371-
// regardless of hardware.
370+
// here by looping over the k dimension and accumulating the results. They should be equal to
371+
// WaveTileK regardless of hardware.
372372
template <typename ADataType,
373373
typename BDataType,
374374
typename CDataType,
375-
uint32_t ChunkM,
376-
uint32_t ChunkN,
377-
uint32_t ChunkK>
375+
uint32_t WaveTileM,
376+
uint32_t WaveTileN,
377+
uint32_t WaveTileK>
378378
__global__ void test_accum_over_k(void* a, void* b, void* c, void* out)
379379
{
380380
using Selector = MmaDefaultSelector<ADataType,
381381
BDataType,
382382
CDataType,
383-
ChunkM,
384-
ChunkN,
385-
ChunkK,
383+
WaveTileM,
384+
WaveTileN,
385+
WaveTileK,
386386
decltype(get_compiler_target()),
387387
MmaOpFamily::DENSE>;
388388

389389
using MmaOp = typename Selector::SelectedOp;
390390
using CVecType = typename MmaOp::CVecType;
391391

392-
static constexpr uint32_t kIters = ChunkK / MmaOp::kK;
392+
static constexpr uint32_t kIters = WaveTileK / MmaOp::kK;
393393

394394
// Initialize the accumulator
395395
CVecType result = *reinterpret_cast<typename MmaOp::CVecType*>(c);
396396

397-
// Accumulate input AxB over ChunkK/FragK iterations
397+
// Accumulate input AxB over WaveTileK / MmaOp::kK iterations
398398
for(uint32_t i = 0; i < kIters; ++i)
399399
{
400400
result = MmaOp::exec(*reinterpret_cast<typename MmaOp::AVecType*>(a),
@@ -430,16 +430,16 @@ TEST(TestAmdgcnMma, MmaSelector_F16_F16_F32_16x16x32_Real)
430430
using BType = fp16_t;
431431
using CType = fp32_t;
432432

433-
// Chunk size, also the expected fragment size from the selector.
433+
// WaveTile size, also the expected fragment size (MmaTile) from the selector.
434434
// Note: Actual FragK might be slightly different due to hardware implementation, but the
435435
// test_accum_over_k kernel will loop over the K dimension to ensure that the total K is
436436
// correct.
437-
static constexpr uint32_t ChunkM = 16;
438-
static constexpr uint32_t ChunkN = 16;
439-
static constexpr uint32_t ChunkK = 32;
440-
static constexpr uint32_t FragM = ChunkM;
441-
static constexpr uint32_t FragN = ChunkN;
442-
static constexpr uint32_t FragK = ChunkK;
437+
static constexpr uint32_t WaveTileM = 16;
438+
static constexpr uint32_t WaveTileN = 16;
439+
static constexpr uint32_t WaveTileK = 32;
440+
static constexpr uint32_t FragM = WaveTileM;
441+
static constexpr uint32_t FragN = WaveTileN;
442+
static constexpr uint32_t FragK = WaveTileK;
443443

444444
// Gfx11 has input data duplication and no accumulator padding (MultiplierC = 1)
445445
// TODO: c++20 use is_target_family_gfx11(currentArchId)
@@ -480,16 +480,16 @@ TEST(TestAmdgcnMma, MmaSelector_F16_F16_F32_16x16x32_Real)
480480
HIP_CHECK_ERROR(hipMemcpy(d_c, h_c.data(), CSize, hipMemcpyHostToDevice));
481481

482482
const auto wave_size = getDeviceWaveSize();
483-
test_accum_over_k<AType, BType, CType, ChunkM, ChunkN, ChunkK>
483+
test_accum_over_k<AType, BType, CType, WaveTileM, WaveTileN, WaveTileK>
484484
<<<1, wave_size>>>(d_a, d_b, d_c, d_out);
485485
HIP_CHECK_ERROR(hipDeviceSynchronize());
486486

487487
HIP_CHECK_ERROR(hipMemcpy(h_out.data(), d_out, CSize, hipMemcpyDeviceToHost));
488488

489-
// Output should be ChunkK for all elements, because the inputs are all 1's
489+
// Output should be WaveTileK for all elements, because the inputs are all 1's
490490
for(size_t i = 0; i < CElements; ++i)
491491
{
492-
CType expected = static_cast<CType>(ChunkK);
492+
CType expected = static_cast<CType>(WaveTileK);
493493

494494
EXPECT_NEAR(h_out[i], expected, 1e-3);
495495
}
@@ -502,7 +502,7 @@ TEST(TestAmdgcnMma, MmaSelector_F16_F16_F32_16x16x32_Real)
502502

503503
// Do a live test. At minimum, there should be a solution on real hardware for F16_F16_F32_16x16x32
504504
// The selector should be able to pick the correct MmaOp as a multiple of 16x16x32, even if the
505-
// chunk sizes are larger than 16x16x32. This tests that the selector can handle larger chunk
505+
// WaveTile sizes are larger than 16x16x32. This tests that the selector can handle larger WaveTile
506506
// sizes and still select the correct MmaOp.
507507
TEST(TestAmdgcnMma, MmaSelector_F16_F16_F32_112x112x128_Real)
508508
{
@@ -528,13 +528,13 @@ TEST(TestAmdgcnMma, MmaSelector_F16_F16_F32_112x112x128_Real)
528528
using BType = fp16_t;
529529
using CType = fp32_t;
530530

531-
// Chunk size to test for decomposition.
532-
// We expect the selector to pick a 16x16 chunk
533-
static constexpr uint32_t ChunkM = 112;
534-
static constexpr uint32_t ChunkN = 112;
535-
static constexpr uint32_t ChunkK = 128;
531+
// WaveTile size to test for decomposition.
532+
// We expect the selector to pick a 16x16 MmaTile (fragment) to decompose the WaveTile
533+
static constexpr uint32_t WaveTileM = 112;
534+
static constexpr uint32_t WaveTileN = 112;
535+
static constexpr uint32_t WaveTileK = 128;
536536

537-
// The expected fragment size from the selector (multiple of 16).
537+
// The expected fragment size from the selector (MmaTile, multiple of 16).
538538
// Note: Actual FragK might be slightly different due to hardware implementation, but the
539539
// test_accum_over_k kernel will loop over the K dimension to ensure that the total K is
540540
// correct.
@@ -581,16 +581,16 @@ TEST(TestAmdgcnMma, MmaSelector_F16_F16_F32_112x112x128_Real)
581581
HIP_CHECK_ERROR(hipMemcpy(d_c, h_c.data(), CSize, hipMemcpyHostToDevice));
582582

583583
const auto wave_size = getDeviceWaveSize();
584-
test_accum_over_k<AType, BType, CType, ChunkM, ChunkN, ChunkK>
584+
test_accum_over_k<AType, BType, CType, WaveTileM, WaveTileN, WaveTileK>
585585
<<<1, wave_size>>>(d_a, d_b, d_c, d_out);
586586
HIP_CHECK_ERROR(hipDeviceSynchronize());
587587

588588
HIP_CHECK_ERROR(hipMemcpy(h_out.data(), d_out, CSize, hipMemcpyDeviceToHost));
589589

590-
// Output should be ChunkK for all elements, because the inputs are all 1's
590+
// Output should be WaveTileK for all elements, because the inputs are all 1's
591591
for(size_t i = 0; i < CElements; ++i)
592592
{
593-
CType expected = static_cast<CType>(ChunkK);
593+
CType expected = static_cast<CType>(WaveTileK);
594594

595595
EXPECT_NEAR(h_out[i], expected, 1e-3);
596596
}

projects/composablekernel/test/ck_tile/core/arch/mma/test_amdgcn_sparse_mma.cpp

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -144,29 +144,29 @@ TEST(SparseMMATrait, SparseSelector)
144144
template <typename AType,
145145
typename BType,
146146
typename CType,
147-
uint32_t ChunkM,
148-
uint32_t ChunkN,
149-
uint32_t ChunkK>
147+
uint32_t WaveTileM,
148+
uint32_t WaveTileN,
149+
uint32_t WaveTileK>
150150
__global__ void test_sparse_accum_over_k(void* a, void* b, void* c, void* out)
151151
{
152152
using CompilerTarget = decltype(get_compiler_target());
153153
using Selector = MmaDefaultSelector<AType,
154154
BType,
155155
CType,
156-
ChunkM,
157-
ChunkN,
158-
ChunkK,
156+
WaveTileM,
157+
WaveTileN,
158+
WaveTileK,
159159
CompilerTarget,
160160
MmaOpFamily::SPARSE>;
161161
using MmaOp = typename Selector::SelectedOp;
162162
using CVecType = typename MmaOp::CVecType;
163163

164-
static constexpr uint32_t kIters = ChunkK / MmaOp::kK;
164+
static constexpr uint32_t kIters = WaveTileK / MmaOp::kK;
165165

166166
// Initialize the accumulator
167167
CVecType result = *reinterpret_cast<typename MmaOp::CVecType*>(c);
168168

169-
// Accumulate input AxB over ChunkK/FragK iterations
169+
// Accumulate input AxB over WaveTileK / MmaOp::kK iterations
170170
for(uint32_t i = 0; i < kIters; ++i)
171171
{
172172
result = MmaOp::exec(*reinterpret_cast<typename MmaOp::AVecType*>(a),
@@ -207,16 +207,16 @@ TEST(SparseMMATrait, MmaSelector_Sparse_F16_F16_F32_16x16x32_Real)
207207
using BType = fp16_t;
208208
using CType = fp32_t;
209209

210-
// Chunk size, also the expected fragment size from the selector.
210+
// WaveTile size, also the expected fragment size (MmaTile) from the selector.
211211
// Note: Actual FragK might be slightly different due to hardware implementation, but the
212212
// test_sparse_accum_over_k kernel will loop over the K dimension to ensure that the total K is
213213
// correct.
214-
static constexpr uint32_t ChunkM = 16;
215-
static constexpr uint32_t ChunkN = 16;
216-
static constexpr uint32_t ChunkK = 32;
217-
static constexpr uint32_t FragM = ChunkM;
218-
static constexpr uint32_t FragN = ChunkN;
219-
static constexpr uint32_t FragK = ChunkK;
214+
static constexpr uint32_t WaveTileM = 16;
215+
static constexpr uint32_t WaveTileN = 16;
216+
static constexpr uint32_t WaveTileK = 32;
217+
static constexpr uint32_t FragM = WaveTileM;
218+
static constexpr uint32_t FragN = WaveTileN;
219+
static constexpr uint32_t FragK = WaveTileK;
220220

221221
// The number of elements per thread
222222
uint32_t AElements = FragM * FragK / deviceWarpSize;

0 commit comments

Comments
 (0)