@@ -18,28 +18,27 @@ namespace ck_tile::core::arch::mma {
1818 * @class MfmaDefaultSelector
1919 * @brief Implements a default MFMA selector strategy for gfx9 target architectures.
2020 * This implements the K dimension search strategy to find the largest supported MFMA
21- * instruction for the given M/N chunk sizes and datatypes.
22- * If no supported instruction is found, falls back to an unsupported pass-through
23- implementation.
24- * @tparam ADataType Data type of matrix A
25- * @tparam BDataType Data type of matrix B
26- * @tparam CDataType Data type of the accumulator
27- * @tparam ChunkM Chunk M dimension size
28- * @tparam ChunkN Chunk N dimension size
29- * @tparam ChunkKTest Current Chunk K dimension size to test
21+ * instruction for the given M/N WaveTile sizes and datatypes.
22+ * If no supported instruction is found, falls back to an unsupported pass-through implementation.
23+ * @tparam ADataType Data type of matrix A
24+ * @tparam BDataType Data type of matrix B
25+ * @tparam CDataType Data type of the accumulator
26+ * @tparam WaveTileM WaveTile M dimension size
27+ * @tparam WaveTileN WaveTile N dimension size
28+ * @tparam WaveTileKTest Current WaveTile K dimension size to test
3029 * @tparam CompilerTarget The compiler target
31- * @note Here we assume that ChunkKTest is always a power-of-two integer.
32- * The search strategy starts from a maximum ChunkKTest size down to 1u by halving
30+ * @note Here we assume that WaveTileKTest is always a power-of-two integer.
31+ * The search strategy starts from a maximum WaveTileKTest size down to 1u by halving
3332 * each time.
3433 */
3534template <typename ADataType,
3635 typename BDataType,
3736 typename CDataType,
38- uint32_t ChunkM ,
39- uint32_t ChunkN ,
40- uint32_t ChunkKTest ,
37+ uint32_t WaveTileM ,
38+ uint32_t WaveTileN ,
39+ uint32_t WaveTileKTest ,
4140 typename CompilerTarget> // TODO: c++20 amdgcn_target_arch_id CompilerTarget>
42- // TODO: c++20 requires(is_gfx9_arch_id(CompilerTarget) && is_power_of_two_integer(ChunkKTest ))
41+ // TODO: c++20 requires(is_gfx9_arch_id(CompilerTarget) && is_power_of_two_integer(WaveTileKTest ))
4342struct MfmaDefaultSelector
4443{
4544 private:
@@ -48,54 +47,60 @@ struct MfmaDefaultSelector
4847 amdgcn_mma<ADataType,
4948 BDataType,
5049 CDataType,
51- ChunkM ,
52- ChunkN ,
53- ChunkKTest ,
50+ WaveTileM ,
51+ WaveTileN ,
52+ WaveTileKTest ,
5453 DefaultMfmaCtrlFlags, // By default, let's assume no special flags for MFMA
5554 CompilerTarget,
5655 MmaOpFamily::DENSE>;
5756
5857 public:
5958 // If the candidate is supported (e.g., a backend implementation exists), then select it.
60- // Otherwise, test another smaller ChunkK . If no existing implementations, we will get ChunkK=0u
61- // and fall back to the unsupported pass-through implementation.
59+ // Otherwise, test another smaller WaveTileK. If no existing implementations, we will get
60+ // WaveTileK=0u and fall back to the unsupported pass-through implementation.
6261 using SelectedOp = std::conditional_t <MmaOpTraits<CandidateOp>::IsSupported,
6362 CandidateOp,
6463 typename MfmaDefaultSelector<ADataType,
6564 BDataType,
6665 CDataType,
67- ChunkM ,
68- ChunkN ,
69- ChunkKTest / 2u ,
66+ WaveTileM ,
67+ WaveTileN ,
68+ WaveTileKTest / 2u ,
7069 CompilerTarget>::SelectedOp>;
7170};
7271
7372/**
7473 * @struct MfmaDefaultSelector
7574 * @brief Implements the base case for the default MFMA selector when no supported instruction is
7675 * found.
77- * @tparam ADataType Data type of matrix A
78- * @tparam BDataType Data type of matrix B
79- * @tparam CDataType Data type of the accumulator
80- * @tparam ChunkM Chunk M dimension size
81- * @tparam ChunkN Chunk N dimension size
76+ * @tparam ADataType Data type of matrix A
77+ * @tparam BDataType Data type of matrix B
78+ * @tparam CDataType Data type of the accumulator
79+ * @tparam WaveTileM WaveTile M dimension size
80+ * @tparam WaveTileN WaveTile N dimension size
8281 * @tparam CompilerTarget The compiler target
8382 */
8483template <typename ADataType,
8584 typename BDataType,
8685 typename CDataType,
87- uint32_t ChunkM ,
88- uint32_t ChunkN ,
86+ uint32_t WaveTileM ,
87+ uint32_t WaveTileN ,
8988 typename CompilerTarget> // TODO: c++20 amdgcn_target_arch_id CompilerTarget>
90- struct MfmaDefaultSelector <ADataType, BDataType, CDataType, ChunkM, ChunkN, 1u , CompilerTarget>
89+ struct MfmaDefaultSelector <ADataType,
90+ BDataType,
91+ CDataType,
92+ WaveTileM,
93+ WaveTileN,
94+ 1u ,
95+ CompilerTarget>
9196{
9297 // Default unsupported pass-through if no instruction is found
9398 using SelectedOp =
9499 amdgcn_mma<ADataType,
95100 BDataType,
96101 CDataType,
97- ChunkM ,
98- ChunkN ,
102+ WaveTileM ,
103+ WaveTileN ,
99104 1u ,
100105 DefaultMfmaCtrlFlags, // By default, let's assume no special flags for MFMA
101106 CompilerTarget,
@@ -105,32 +110,32 @@ struct MfmaDefaultSelector<ADataType, BDataType, CDataType, ChunkM, ChunkN, 1u,
105110/**
106111 * @struct MmaDefaultSelector
107112 * @brief Implements the gfx9 default MMA selector strategy for wave-wise MMA decomposition.
108- * This implements the M/N chunk size search strategy to find the largest supported MFMA
113+ * This implements the M/N WaveTile size search strategy to find the largest supported MFMA
109114 * instruction for the given datatypes.
110115 * If no supported instruction is found, falls back to an unsupported pass-through implementation.
111- * @tparam ADataType Data type of matrix A
112- * @tparam BDataType Data type of matrix B
113- * @tparam CDataType Data type of the accumulator
114- * @tparam ChunkM Size of the M dimension of the chunk to decompose
115- * @tparam ChunkN Size of the N dimension of the chunk to decompose
116- * @tparam ChunkK Size of the K dimension of the chunk to decompose
116+ * @tparam ADataType Data type of matrix A
117+ * @tparam BDataType Data type of matrix B
118+ * @tparam CDataType Data type of the accumulator
119+ * @tparam WaveTileM Size of the M dimension of the WaveTile to decompose
120+ * @tparam WaveTileN Size of the N dimension of the WaveTile to decompose
121+ * @tparam WaveTileK Size of the K dimension of the WaveTile to decompose
117122 * @tparam CompilerTarget The compiler target
118- * @tparam OpFamily The MMA operation family
123+ * @tparam OpFamily The MMA operation family
119124 */
120125template <typename ADataType,
121126 typename BDataType,
122127 typename CDataType,
123- uint32_t ChunkM ,
124- uint32_t ChunkN ,
125- uint32_t ChunkK ,
128+ uint32_t WaveTileM ,
129+ uint32_t WaveTileN ,
130+ uint32_t WaveTileK ,
126131 typename CompilerTarget,
127132 MmaOpFamily OpFamily> // TODO: c++20 amdgcn_target_arch_id CompilerTarget>
128133struct MmaDefaultSelector <ADataType,
129134 BDataType,
130135 CDataType,
131- ChunkM ,
132- ChunkN ,
133- ChunkK ,
136+ WaveTileM ,
137+ WaveTileN ,
138+ WaveTileK ,
134139 CompilerTarget,
135140 OpFamily,
136141 enable_if_all<enable_if_target_family_gfx9_t <CompilerTarget>,
@@ -162,20 +167,20 @@ struct MmaDefaultSelector<ADataType,
162167 typename MfmaDefaultSelector<ADataType, BDataType, CDataType, 1u , 1u , 1u , CompilerTarget>::
163168 SelectedOp;
164169
165- // Check if each candidate is supported for the given chunk sizes
166- // For this case, we require the chunk sizes to be multiples of the MFMA shape
170+ // Check if each candidate is supported for the given WaveTile sizes
171+ // For this case, we require the WaveTile sizes to be multiples of the MFMA shape
167172 static constexpr bool IsSupported4x4 =
168- MmaOpTraits<CandidateOp4x4>::IsSupported && (ChunkM % CandidateOp4x4::kM == 0u ) &&
169- (ChunkN % CandidateOp4x4::kN == 0u ) && (ChunkK % CandidateOp4x4::kK == 0u );
173+ MmaOpTraits<CandidateOp4x4>::IsSupported && (WaveTileM % CandidateOp4x4::kM == 0u ) &&
174+ (WaveTileN % CandidateOp4x4::kN == 0u ) && (WaveTileK % CandidateOp4x4::kK == 0u );
170175 static constexpr bool IsSupported16x16 =
171- MmaOpTraits<CandidateOp16x16>::IsSupported && (ChunkM % CandidateOp16x16::kM == 0u ) &&
172- (ChunkN % CandidateOp16x16::kN == 0u ) && (ChunkK % CandidateOp16x16::kK == 0u );
176+ MmaOpTraits<CandidateOp16x16>::IsSupported && (WaveTileM % CandidateOp16x16::kM == 0u ) &&
177+ (WaveTileN % CandidateOp16x16::kN == 0u ) && (WaveTileK % CandidateOp16x16::kK == 0u );
173178 static constexpr bool IsSupported32x32 =
174- MmaOpTraits<CandidateOp32x32>::IsSupported && (ChunkM % CandidateOp32x32::kM == 0u ) &&
175- (ChunkN % CandidateOp32x32::kN == 0u ) && (ChunkK % CandidateOp32x32::kK == 0u );
179+ MmaOpTraits<CandidateOp32x32>::IsSupported && (WaveTileM % CandidateOp32x32::kM == 0u ) &&
180+ (WaveTileN % CandidateOp32x32::kN == 0u ) && (WaveTileK % CandidateOp32x32::kK == 0u );
176181
177182 public:
178- // Select the largest supported MFMA operation for the given chunk shape
183+ // Select the largest supported MFMA operation for the given WaveTile shape
179184 using SelectedOp = std::conditional_t <
180185 IsSupported32x32,
181186 CandidateOp32x32,
0 commit comments