Skip to content

Commit 7a1ca1a

Browse files
committed
Address PR comments except for those about layout explanations. Also fix a number of previously missed Block / Frag / Chunk refactor spots.
1 parent 9ccb00e commit 7a1ca1a

File tree

14 files changed

+246
-338
lines changed

14 files changed

+246
-338
lines changed

projects/composablekernel/include/ck_tile/core/arch/mma/amdgcn_mma.hpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ namespace ck_tile::core::arch::mma {
1717
// TODO: Describe layout params.
1818
/**
1919
* @class amdgcn_mma_base
20-
* @brief Helper base class for amdgcn_mma structs to avoid a lot of code duplication. Also puts
20+
* @brief Base class for amdgcn_mma structs to avoid a lot of code duplication. Also puts
2121
* all generic parameter derivations and static asserts in one place. Houses all of the
2222
* amdgcn struct types and variables, except for the exec() function.
2323
*/
@@ -127,14 +127,13 @@ concept MmaOpI = requires(MmaOp op) {
127127
* @tparam ADataType Datatype of input A
128128
* @tparam BDataType Datatype of input B
129129
* @tparam CDataType Datatype of accumulator
130-
* @tparam FragM M-dimension of mma block
131-
* @tparam FragN N-dimension of mma block
132-
* @tparam FragK K-dimension of mma block
130+
* @tparam FragM M-dimension of mma intrinsic
131+
* @tparam FragN N-dimension of mma intrinsic
132+
* @tparam FragK K-dimension of mma intrinsic
133133
* @tparam CtrlFlags Control flags for mma operation
134134
* @tparam CompilerTarget The current compiler target
135135
* @tparam Enabler SFINAE enabler
136136
*/
137-
// clang-format off
138137
template <typename ADataType,
139138
typename BDataType,
140139
typename CDataType,
@@ -145,17 +144,19 @@ template <typename ADataType,
145144
typename CompilerTarget,
146145
MmaOpFamily OpFamily_,
147146
typename Enabler = void>
147+
// clang-format off
148148
struct amdgcn_mma : amdgcn_mma_base<fp32_t, fp32_t, fp32_t, 1u, 1u, 1u, 1u, 1, 1, 1, 1, 1, 1, 1, Unsupported, MmaOpFamily::UNDEFINED>
149+
// clang-format on
149150
{
150151
// This is a default pass-through implementation that doesn't do anything practical.
151152
CK_TILE_DEVICE static CVecType const&
152153
exec(AVecType const& regsA, BVecType const& regsB, CVecType const& regsC)
153154
{
155+
printf("[WARNING] Running amdgcn_mma dummy exec function!\n");
154156
ignore(regsA, regsB);
155157
return regsC; // No-op, just return C
156158
}
157159
};
158-
// clang-format on
159160

160161
} // namespace ck_tile::core::arch::mma
161162
#pragma clang diagnostic pop

projects/composablekernel/include/ck_tile/core/arch/mma/mfma/mfma_gfx9.hpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,17 +53,18 @@ concept CtrlFlagsGfx9I = requires(CtrlFlags ctrlFlags) {
5353
* @brief Specialization of amdgcn_mma for MFMA on GFX9 targets
5454
*
5555
* This specialization implements the MFMA instruction for fp16_t A and B
56-
* matrices, and fp32_t accumulator matrix, with 16x16x16 block sizes.
56+
* matrices, and fp32_t accumulator matrix, with 16x16x16 fragment sizes.
5757
*
5858
* @tparam CtrlFlags Control flags for the MFMA operation
5959
* @tparam CompilerTarget Current compiler target
6060
*/
6161
// TODO: c++20 template <CtrlFlagsGfx9I CtrlFlags, amdgcn_target CompilerTarget>
6262
// TODO: c++20 requires
63-
// clang-format off
6463
template <typename CtrlFlags, typename CompilerTarget>
64+
// clang-format off
6565
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx9_t<CompilerTarget>>
6666
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, 64u, 4, 1, 1, 1, 1, 4, 1, MfmaOp, MmaOpFamily::DENSE>
67+
// clang-format on
6768
{
6869
CK_TILE_DEVICE static auto
6970
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
@@ -82,16 +83,18 @@ struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarg
8283
* @brief Specialization of amdgcn_mma for MFMA on GFX950 targets
8384
*
8485
* This specialization implements the MFMA instruction for fp16_t A and B
85-
* matrices, and fp32_t accumulator matrix, with 16x16x32 block sizes.
86+
* matrices, and fp32_t accumulator matrix, with 16x16x32 fragment sizes.
8687
*
8788
* @tparam CtrlFlags Control flags for the MFMA operation
8889
* @tparam CompilerTarget Current compiler target
8990
*/
9091
// TODO: c++20 template <CtrlFlagsGfx9I CtrlFlags, amdgcn_target CompilerTarget>
9192
// TODO: c++20 requires
9293
template <typename CtrlFlags, typename CompilerTarget>
94+
// clang-format off
9395
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
9496
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u, 64u, 8, 1, 1, 1, 1, 4, 1, MfmaOp, MmaOpFamily::DENSE>
97+
// clang-format on
9598
{
9699
CK_TILE_DEVICE static auto
97100
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
@@ -104,6 +107,5 @@ struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarg
104107
static_cast<int>(CtrlFlags::Blgp))};
105108
}
106109
};
107-
// clang-format on
108110

109111
} // namespace ck_tile::core::arch::mma

projects/composablekernel/include/ck_tile/core/arch/mma/mfma/mfma_selector.hpp

Lines changed: 46 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -18,28 +18,28 @@ namespace ck_tile::core::arch::mma {
1818
* @class MfmaDefaultSelector
1919
* @brief Implements a default MFMA selector strategy for gfx9 target architectures.
2020
* This implements the K dimension search strategy to find the largest supported MFMA
21-
* instruction for the given M/N block sizes and datatypes.
21+
* instruction for the given M/N chunk sizes and datatypes.
2222
* If no supported instruction is found, falls back to an unsupported pass-through
2323
* implementation.
2424
* @tparam ADataType Data type of matrix A
2525
* @tparam BDataType Data type of matrix B
2626
* @tparam CDataType Data type of the accumulator
27-
* @tparam FragM Block M dimension size
28-
* @tparam FragN Block N dimension size
29-
* @tparam FragKTest Current Block K dimension size to test
27+
* @tparam ChunkM Chunk M dimension size
28+
* @tparam ChunkN Chunk N dimension size
29+
* @tparam ChunkKTest Current Chunk K dimension size to test
3030
* @tparam CompilerTarget The compiler target
31-
* @note Here we assume that FragKTest is always a power-of-two integer.
32-
* The search strategy starts from a maximum FragKTest size down to 1u by halving
31+
* @note Here we assume that ChunkKTest is always a power-of-two integer.
32+
* The search strategy starts from a maximum ChunkKTest size down to 1u by halving
3333
* each time.
3434
*/
3535
template <typename ADataType,
3636
typename BDataType,
3737
typename CDataType,
38-
uint32_t FragM,
39-
uint32_t FragN,
40-
uint32_t FragKTest,
38+
uint32_t ChunkM,
39+
uint32_t ChunkN,
40+
uint32_t ChunkKTest,
4141
typename CompilerTarget> // TODO: c++20 amdgcn_target_arch_id CompilerTarget>
42-
// TODO: c++20 requires(is_gfx9_arch_id(CompilerTarget) && is_power_of_two_integer(FragKTest))
42+
// TODO: c++20 requires(is_gfx9_arch_id(CompilerTarget) && is_power_of_two_integer(ChunkKTest))
4343
struct MfmaDefaultSelector
4444
{
4545
private:
@@ -48,25 +48,25 @@ struct MfmaDefaultSelector
4848
amdgcn_mma<ADataType,
4949
BDataType,
5050
CDataType,
51-
FragM,
52-
FragN,
53-
FragKTest,
51+
ChunkM,
52+
ChunkN,
53+
ChunkKTest,
5454
DefaultMfmaCtrlFlags, // By default, let's assume no special flags for MFMA
5555
CompilerTarget,
5656
MmaOpFamily::DENSE>;
5757

5858
public:
5959
// If the candidate is supported (e.g., a backend implementation exists), then select it.
60-
// Otherwise, test another smaller FragK. If no existing implementations, we will get FragK=0u
60+
// Otherwise, test another smaller ChunkK. If no existing implementations, we will get ChunkK=0u
6161
// and fall back to the unsupported pass-through implementation.
6262
using SelectedOp = std::conditional_t<MmaOpTraits<CandidateOp>::IsSupported,
6363
CandidateOp,
6464
typename MfmaDefaultSelector<ADataType,
6565
BDataType,
6666
CDataType,
67-
FragM,
68-
FragN,
69-
FragKTest / 2u,
67+
ChunkM,
68+
ChunkN,
69+
ChunkKTest / 2u,
7070
CompilerTarget>::SelectedOp>;
7171
};
7272

@@ -77,25 +77,25 @@ struct MfmaDefaultSelector
7777
* @tparam ADataType Data type of matrix A
7878
* @tparam BDataType Data type of matrix B
7979
* @tparam CDataType Data type of the accumulator
80-
* @tparam FragM Block M dimension size
81-
* @tparam FragN Block N dimension size
80+
* @tparam ChunkM Chunk M dimension size
81+
* @tparam ChunkN Chunk N dimension size
8282
* @tparam CompilerTarget The compiler target
8383
*/
8484
template <typename ADataType,
8585
typename BDataType,
8686
typename CDataType,
87-
uint32_t FragM,
88-
uint32_t FragN,
87+
uint32_t ChunkM,
88+
uint32_t ChunkN,
8989
typename CompilerTarget> // TODO: c++20 amdgcn_target_arch_id CompilerTarget>
90-
struct MfmaDefaultSelector<ADataType, BDataType, CDataType, FragM, FragN, 1u, CompilerTarget>
90+
struct MfmaDefaultSelector<ADataType, BDataType, CDataType, ChunkM, ChunkN, 1u, CompilerTarget>
9191
{
9292
// Default unsupported pass-through if no instruction is found
9393
using SelectedOp =
9494
amdgcn_mma<ADataType,
9595
BDataType,
9696
CDataType,
97-
FragM,
98-
FragN,
97+
ChunkM,
98+
ChunkN,
9999
1u,
100100
DefaultMfmaCtrlFlags, // By default, let's assume no special flags for MFMA
101101
CompilerTarget,
@@ -105,32 +105,32 @@ struct MfmaDefaultSelector<ADataType, BDataType, CDataType, FragM, FragN, 1u, Co
105105
/**
106106
* @struct MmaDefaultSelector
107107
* @brief Implements the gfx9 default MMA selector strategy for wave-wise MMA decomposition.
108-
* This implements the M/N block size search strategy to find the largest supported MFMA
108+
* This implements the M/N chunk size search strategy to find the largest supported MFMA
109109
* instruction for the given datatypes.
110110
* If no supported instruction is found, falls back to an unsupported pass-through implementation.
111111
* @tparam ADataType Data type of matrix A
112112
* @tparam BDataType Data type of matrix B
113113
* @tparam CDataType Data type of the accumulator
114-
* @tparam FragM Size of the M dimension of the fragment to decompose
115-
* @tparam FragN Size of the N dimension of the fragment to decompose
116-
* @tparam FragK Size of the K dimension of the fragment to decompose
114+
* @tparam ChunkM Size of the M dimension of the chunk to decompose
115+
* @tparam ChunkN Size of the N dimension of the chunk to decompose
116+
* @tparam ChunkK Size of the K dimension of the chunk to decompose
117117
* @tparam CompilerTarget The compiler target
118118
* @tparam OpFamily The MMA operation family
119119
*/
120120
template <typename ADataType,
121121
typename BDataType,
122122
typename CDataType,
123-
uint32_t FragM,
124-
uint32_t FragN,
125-
uint32_t FragK,
123+
uint32_t ChunkM,
124+
uint32_t ChunkN,
125+
uint32_t ChunkK,
126126
typename CompilerTarget,
127127
MmaOpFamily OpFamily> // TODO: c++20 amdgcn_target_arch_id CompilerTarget>
128128
struct MmaDefaultSelector<ADataType,
129129
BDataType,
130130
CDataType,
131-
FragM,
132-
FragN,
133-
FragK,
131+
ChunkM,
132+
ChunkN,
133+
ChunkK,
134134
CompilerTarget,
135135
OpFamily,
136136
enable_if_all<enable_if_target_family_gfx9_t<CompilerTarget>,
@@ -162,23 +162,20 @@ struct MmaDefaultSelector<ADataType,
162162
typename MfmaDefaultSelector<ADataType, BDataType, CDataType, 1u, 1u, 1u, CompilerTarget>::
163163
SelectedOp;
164164

165-
// Check if each candidate is supported for the given fragment sizes
166-
// For this case, we require the fragment sizes to be multiples of the MFMA shape
167-
static constexpr bool IsSupported4x4 = MmaOpTraits<CandidateOp4x4>::IsSupported &&
168-
(FragM % CandidateOp4x4::kM == 0u) &&
169-
(FragN % CandidateOp4x4::kN == 0u) &&
170-
(FragK % CandidateOp4x4::kK == 0u);
171-
static constexpr bool IsSupported16x16 = MmaOpTraits<CandidateOp16x16>::IsSupported &&
172-
(FragM % CandidateOp16x16::kM == 0u) &&
173-
(FragN % CandidateOp16x16::kN == 0u) &&
174-
(FragK % CandidateOp16x16::kK == 0u);
175-
static constexpr bool IsSupported32x32 = MmaOpTraits<CandidateOp32x32>::IsSupported &&
176-
(FragM % CandidateOp32x32::kM == 0u) &&
177-
(FragN % CandidateOp32x32::kN == 0u) &&
178-
(FragK % CandidateOp32x32::kK == 0u);
165+
// Check if each candidate is supported for the given chunk sizes
166+
// For this case, we require the chunk sizes to be multiples of the MFMA shape
167+
static constexpr bool IsSupported4x4 =
168+
MmaOpTraits<CandidateOp4x4>::IsSupported && (ChunkM % CandidateOp4x4::kM == 0u) &&
169+
(ChunkN % CandidateOp4x4::kN == 0u) && (ChunkK % CandidateOp4x4::kK == 0u);
170+
static constexpr bool IsSupported16x16 =
171+
MmaOpTraits<CandidateOp16x16>::IsSupported && (ChunkM % CandidateOp16x16::kM == 0u) &&
172+
(ChunkN % CandidateOp16x16::kN == 0u) && (ChunkK % CandidateOp16x16::kK == 0u);
173+
static constexpr bool IsSupported32x32 =
174+
MmaOpTraits<CandidateOp32x32>::IsSupported && (ChunkM % CandidateOp32x32::kM == 0u) &&
175+
(ChunkN % CandidateOp32x32::kN == 0u) && (ChunkK % CandidateOp32x32::kK == 0u);
179176

180177
public:
181-
// Select the largest supported MFMA operation for the given fragment shape
178+
// Select the largest supported MFMA operation for the given chunk shape
182179
using SelectedOp = std::conditional_t<
183180
IsSupported32x32,
184181
CandidateOp32x32,

0 commit comments

Comments
 (0)