Skip to content

Commit 2dbdd57

Browse files
committed
Address review comments: tweak comments + change Chunk to WaveTile + reduce dummy exec print verbosity.
1 parent 35a0c20 commit 2dbdd57

File tree

7 files changed

+269
-252
lines changed

7 files changed

+269
-252
lines changed

projects/composablekernel/include/ck_tile/core/arch/mma/amdgcn_mma.hpp

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,13 @@ namespace ck_tile::core::arch::mma {
1818
* Meaning of amdgcn_mma layout parameters (general)
1919
* ---------------------------------------------------
2020
*
21-
* The fragment sizes and layout constants in the amdgcn_mma struct describe the mapping between
22-
* intrinsic input / output matrix elements and vector registers (lane x vector_item space). Note
23-
* that we end up having a mapping for A, B and C separately, although those for A and B are usually
24-
* similar if not identical. All mappings can be described as an unmerge operation on one of the
25-
* matrix dims (either K for AB or M for C), followed by remerging of the resulting subdims and raw
26-
* other dim into the Lane and Vector_item dimensions. When considering an unmerge operation on a
27-
* dimension K, we can label the resulting sub-dimensions as K0, K1, and K2, where K0 is the size
21+
* The fragment (MmaTile) sizes and layout constants in the amdgcn_mma struct describe the mapping
22+
* between intrinsic input / output matrix elements and vector registers (lane x vector_item space).
23+
* Note that we end up having a mapping for A, B and C separately, although those for A and B are
24+
* usually similar if not identical. All mappings can be described as an unmerge operation on one of
25+
* the matrix dims (either K for AB or M for C), followed by remerging of the resulting subdims and
26+
* raw other dim into the Lane and Vector_item dimensions. When considering an unmerge operation on
27+
* a dimension K, we can label the resulting sub-dimensions as K0, K1, and K2, where K0 is the size
2828
* of the fastest changing dimension. K0 is also referred to as "The size of the first unmerge", and
2929
* K1 would be "The size of the second unmerge". There are never more than 2 unmerge operations, and
3030
* unmerge operations may be trivial (unmerge size of 1). Example double unmerge of size {3, 2} of a
@@ -96,7 +96,7 @@ namespace ck_tile::core::arch::mma {
9696
*
9797
* -- A / B Repeat --
9898
* Variable indicating that all matrix values are represented multiple times in the vector
99-
* reigsters, typically repeating in the lane dimension. This is always equal to the repeat value
99+
* registers, typically repeating in the lane dimension. This is always equal to the repeat value
100100
* used in Tile Distribution encodings. There are two reasons to have non-trivial (non-1) value
101101
* here: MFMA block-hiding to create oblong "virtual" intrinsics, and RDNA3 input repetition.
102102
*
@@ -143,7 +143,7 @@ struct amdgcn_mma_base
143143
using BDataType = BDataType_;
144144
using CDataType = CDataType_;
145145

146-
// Fragment sizes, check description above.
146+
// Fragment (MmaTile) sizes, check description above.
147147
static constexpr index_t kM = FragM; // M = M2 * M1 * M0
148148
static constexpr index_t kN = FragN;
149149
static constexpr index_t kK = FragK; // K = K2 * K1 * K0
@@ -224,9 +224,9 @@ concept MmaOpI = requires(MmaOp op) {
224224
* @tparam ADataType Datatype of input A
225225
* @tparam BDataType Datatype of input B
226226
* @tparam CDataType Datatype of accumulator
227-
* @tparam FragM M-dimension of mma intrinsic
228-
* @tparam FragN N-dimension of mma intrinsic
229-
* @tparam FragK K-dimension of mma intrinsic
227+
* @tparam FragM M-dimension of mma intrinsic (MmaTile)
228+
* @tparam FragN N-dimension of mma intrinsic (MmaTile)
229+
* @tparam FragK K-dimension of mma intrinsic (MmaTile)
230230
* @tparam CtrlFlags Control flags for mma operation
231231
* @tparam CompilerTarget The current compiler target
232232
* @tparam OpFamily_ The type of operation (dense, sparse, scale, etc.)
@@ -251,7 +251,13 @@ struct amdgcn_mma : amdgcn_mma_base<fp32_t, fp32_t, fp32_t, 1u, 1u, 1u, 1u, 1, 1
251251
CK_TILE_DEVICE static CVecType const&
252252
exec(AVecType const& regsA, BVecType const& regsB, CVecType const& regsC)
253253
{
254-
printf("[WARNING] Running amdgcn_mma dummy exec function!\n");
254+
// Prints once across all thread blocks and threads.
255+
static __device__ int printed = 0;
256+
if(threadIdx.x == 0 && atomicCAS(&printed, 0, 1) == 0)
257+
{
258+
printf("[WARNING] Running amdgcn_mma dummy exec function!\n");
259+
}
260+
255261
ignore(regsA, regsB);
256262
return regsC; // No-op, just return C
257263
}

projects/composablekernel/include/ck_tile/core/arch/mma/mfma/mfma_selector.hpp

Lines changed: 61 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -18,28 +18,27 @@ namespace ck_tile::core::arch::mma {
1818
* @class MfmaDefaultSelector
1919
* @brief Implements a default MFMA selector strategy for gfx9 target architectures.
2020
* This implements the K dimension search strategy to find the largest supported MFMA
21-
* instruction for the given M/N chunk sizes and datatypes.
22-
* If no supported instruction is found, falls back to an unsupported pass-through
23-
implementation.
24-
* @tparam ADataType Data type of matrix A
25-
* @tparam BDataType Data type of matrix B
26-
* @tparam CDataType Data type of the accumulator
27-
* @tparam ChunkM Chunk M dimension size
28-
* @tparam ChunkN Chunk N dimension size
29-
* @tparam ChunkKTest Current Chunk K dimension size to test
21+
* instruction for the given M/N WaveTile sizes and datatypes.
22+
* If no supported instruction is found, falls back to an unsupported pass-through implementation.
23+
* @tparam ADataType Data type of matrix A
24+
* @tparam BDataType Data type of matrix B
25+
* @tparam CDataType Data type of the accumulator
26+
* @tparam WaveTileM WaveTile M dimension size
27+
* @tparam WaveTileN WaveTile N dimension size
28+
* @tparam WaveTileKTest Current WaveTile K dimension size to test
3029
* @tparam CompilerTarget The compiler target
31-
* @note Here we assume that ChunkKTest is always a power-of-two integer.
32-
* The search strategy starts from a maximum ChunkKTest size down to 1u by halving
30+
* @note Here we assume that WaveTileKTest is always a power-of-two integer.
31+
* The search strategy starts from a maximum WaveTileKTest size down to 1u by halving
3332
* each time.
3433
*/
3534
template <typename ADataType,
3635
typename BDataType,
3736
typename CDataType,
38-
uint32_t ChunkM,
39-
uint32_t ChunkN,
40-
uint32_t ChunkKTest,
37+
uint32_t WaveTileM,
38+
uint32_t WaveTileN,
39+
uint32_t WaveTileKTest,
4140
typename CompilerTarget> // TODO: c++20 amdgcn_target_arch_id CompilerTarget>
42-
// TODO: c++20 requires(is_gfx9_arch_id(CompilerTarget) && is_power_of_two_integer(ChunkKTest))
41+
// TODO: c++20 requires(is_gfx9_arch_id(CompilerTarget) && is_power_of_two_integer(WaveTileKTest))
4342
struct MfmaDefaultSelector
4443
{
4544
private:
@@ -48,54 +47,60 @@ struct MfmaDefaultSelector
4847
amdgcn_mma<ADataType,
4948
BDataType,
5049
CDataType,
51-
ChunkM,
52-
ChunkN,
53-
ChunkKTest,
50+
WaveTileM,
51+
WaveTileN,
52+
WaveTileKTest,
5453
DefaultMfmaCtrlFlags, // By default, let's assume no special flags for MFMA
5554
CompilerTarget,
5655
MmaOpFamily::DENSE>;
5756

5857
public:
5958
// If the candidate is supported (e.g., a backend implementation exists), then select it.
60-
// Otherwise, test another smaller ChunkK. If no existing implementations, we will get ChunkK=0u
61-
// and fall back to the unsupported pass-through implementation.
59+
// Otherwise, test another smaller WaveTileK. If no existing implementations, we will get
60+
// WaveTileK=0u and fall back to the unsupported pass-through implementation.
6261
using SelectedOp = std::conditional_t<MmaOpTraits<CandidateOp>::IsSupported,
6362
CandidateOp,
6463
typename MfmaDefaultSelector<ADataType,
6564
BDataType,
6665
CDataType,
67-
ChunkM,
68-
ChunkN,
69-
ChunkKTest / 2u,
66+
WaveTileM,
67+
WaveTileN,
68+
WaveTileKTest / 2u,
7069
CompilerTarget>::SelectedOp>;
7170
};
7271

7372
/**
7473
* @struct MfmaDefaultSelector
7574
* @brief Implements the base case for the default MFMA selector when no supported instruction is
7675
* found.
77-
* @tparam ADataType Data type of matrix A
78-
* @tparam BDataType Data type of matrix B
79-
* @tparam CDataType Data type of the accumulator
80-
* @tparam ChunkM Chunk M dimension size
81-
* @tparam ChunkN Chunk N dimension size
76+
* @tparam ADataType Data type of matrix A
77+
* @tparam BDataType Data type of matrix B
78+
* @tparam CDataType Data type of the accumulator
79+
* @tparam WaveTileM WaveTile M dimension size
80+
* @tparam WaveTileN WaveTile N dimension size
8281
* @tparam CompilerTarget The compiler target
8382
*/
8483
template <typename ADataType,
8584
typename BDataType,
8685
typename CDataType,
87-
uint32_t ChunkM,
88-
uint32_t ChunkN,
86+
uint32_t WaveTileM,
87+
uint32_t WaveTileN,
8988
typename CompilerTarget> // TODO: c++20 amdgcn_target_arch_id CompilerTarget>
90-
struct MfmaDefaultSelector<ADataType, BDataType, CDataType, ChunkM, ChunkN, 1u, CompilerTarget>
89+
struct MfmaDefaultSelector<ADataType,
90+
BDataType,
91+
CDataType,
92+
WaveTileM,
93+
WaveTileN,
94+
1u,
95+
CompilerTarget>
9196
{
9297
// Default unsupported pass-through if no instruction is found
9398
using SelectedOp =
9499
amdgcn_mma<ADataType,
95100
BDataType,
96101
CDataType,
97-
ChunkM,
98-
ChunkN,
102+
WaveTileM,
103+
WaveTileN,
99104
1u,
100105
DefaultMfmaCtrlFlags, // By default, let's assume no special flags for MFMA
101106
CompilerTarget,
@@ -105,32 +110,32 @@ struct MfmaDefaultSelector<ADataType, BDataType, CDataType, ChunkM, ChunkN, 1u,
105110
/**
106111
* @struct MmaDefaultSelector
107112
* @brief Implements the gfx9 default MMA selector strategy for wave-wise MMA decomposition.
108-
* This implements the M/N chunk size search strategy to find the largest supported MFMA
113+
* This implements the M/N WaveTile size search strategy to find the largest supported MFMA
109114
* instruction for the given datatypes.
110115
* If no supported instruction is found, falls back to an unsupported pass-through implementation.
111-
* @tparam ADataType Data type of matrix A
112-
* @tparam BDataType Data type of matrix B
113-
* @tparam CDataType Data type of the accumulator
114-
* @tparam ChunkM Size of the M dimension of the chunk to decompose
115-
* @tparam ChunkN Size of the N dimension of the chunk to decompose
116-
* @tparam ChunkK Size of the K dimension of the chunk to decompose
116+
* @tparam ADataType Data type of matrix A
117+
* @tparam BDataType Data type of matrix B
118+
* @tparam CDataType Data type of the accumulator
119+
* @tparam WaveTileM Size of the M dimension of the WaveTile to decompose
120+
* @tparam WaveTileN Size of the N dimension of the WaveTile to decompose
121+
* @tparam WaveTileK Size of the K dimension of the WaveTile to decompose
117122
* @tparam CompilerTarget The compiler target
118-
* @tparam OpFamily The MMA operation family
123+
* @tparam OpFamily The MMA operation family
119124
*/
120125
template <typename ADataType,
121126
typename BDataType,
122127
typename CDataType,
123-
uint32_t ChunkM,
124-
uint32_t ChunkN,
125-
uint32_t ChunkK,
128+
uint32_t WaveTileM,
129+
uint32_t WaveTileN,
130+
uint32_t WaveTileK,
126131
typename CompilerTarget,
127132
MmaOpFamily OpFamily> // TODO: c++20 amdgcn_target_arch_id CompilerTarget>
128133
struct MmaDefaultSelector<ADataType,
129134
BDataType,
130135
CDataType,
131-
ChunkM,
132-
ChunkN,
133-
ChunkK,
136+
WaveTileM,
137+
WaveTileN,
138+
WaveTileK,
134139
CompilerTarget,
135140
OpFamily,
136141
enable_if_all<enable_if_target_family_gfx9_t<CompilerTarget>,
@@ -162,20 +167,20 @@ struct MmaDefaultSelector<ADataType,
162167
typename MfmaDefaultSelector<ADataType, BDataType, CDataType, 1u, 1u, 1u, CompilerTarget>::
163168
SelectedOp;
164169

165-
// Check if each candidate is supported for the given chunk sizes
166-
// For this case, we require the chunk sizes to be multiples of the MFMA shape
170+
// Check if each candidate is supported for the given WaveTile sizes
171+
// For this case, we require the WaveTile sizes to be multiples of the MFMA shape
167172
static constexpr bool IsSupported4x4 =
168-
MmaOpTraits<CandidateOp4x4>::IsSupported && (ChunkM % CandidateOp4x4::kM == 0u) &&
169-
(ChunkN % CandidateOp4x4::kN == 0u) && (ChunkK % CandidateOp4x4::kK == 0u);
173+
MmaOpTraits<CandidateOp4x4>::IsSupported && (WaveTileM % CandidateOp4x4::kM == 0u) &&
174+
(WaveTileN % CandidateOp4x4::kN == 0u) && (WaveTileK % CandidateOp4x4::kK == 0u);
170175
static constexpr bool IsSupported16x16 =
171-
MmaOpTraits<CandidateOp16x16>::IsSupported && (ChunkM % CandidateOp16x16::kM == 0u) &&
172-
(ChunkN % CandidateOp16x16::kN == 0u) && (ChunkK % CandidateOp16x16::kK == 0u);
176+
MmaOpTraits<CandidateOp16x16>::IsSupported && (WaveTileM % CandidateOp16x16::kM == 0u) &&
177+
(WaveTileN % CandidateOp16x16::kN == 0u) && (WaveTileK % CandidateOp16x16::kK == 0u);
173178
static constexpr bool IsSupported32x32 =
174-
MmaOpTraits<CandidateOp32x32>::IsSupported && (ChunkM % CandidateOp32x32::kM == 0u) &&
175-
(ChunkN % CandidateOp32x32::kN == 0u) && (ChunkK % CandidateOp32x32::kK == 0u);
179+
MmaOpTraits<CandidateOp32x32>::IsSupported && (WaveTileM % CandidateOp32x32::kM == 0u) &&
180+
(WaveTileN % CandidateOp32x32::kN == 0u) && (WaveTileK % CandidateOp32x32::kK == 0u);
176181

177182
public:
178-
// Select the largest supported MFMA operation for the given chunk shape
183+
// Select the largest supported MFMA operation for the given WaveTile shape
179184
using SelectedOp = std::conditional_t<
180185
IsSupported32x32,
181186
CandidateOp32x32,

0 commit comments

Comments
 (0)