@@ -18,28 +18,28 @@ namespace ck_tile::core::arch::mma {
1818 * @class MfmaDefaultSelector
1919 * @brief Implements a default MFMA selector strategy for gfx9 target architectures.
2020 * This implements the K dimension search strategy to find the largest supported MFMA
21- * instruction for the given M/N block sizes and datatypes.
21+ * instruction for the given M/N chunk sizes and datatypes.
2222 * If no supported instruction is found, falls back to an unsupported pass-through
2323 implementation.
2424 * @tparam ADataType Data type of matrix A
2525 * @tparam BDataType Data type of matrix B
2626 * @tparam CDataType Data type of the accumulator
27- * @tparam FragM Block M dimension size
28- * @tparam FragN Block N dimension size
29- * @tparam FragKTest Current Block K dimension size to test
27+ * @tparam ChunkM Chunk M dimension size
28+ * @tparam ChunkN Chunk N dimension size
29+ * @tparam ChunkKTest Current Chunk K dimension size to test
3030 * @tparam CompilerTarget The compiler target
31- * @note Here we assume that FragKTest is always a power-of-two integer.
32- * The search strategy starts from a maximum FragKTest size down to 1u by halving
31+ * @note Here we assume that ChunkKTest is always a power-of-two integer.
32+ * The search strategy starts from a maximum ChunkKTest size down to 1u by halving
3333 * each time.
3434 */
3535template <typename ADataType,
3636 typename BDataType,
3737 typename CDataType,
38- uint32_t FragM ,
39- uint32_t FragN ,
40- uint32_t FragKTest ,
38+ uint32_t ChunkM ,
39+ uint32_t ChunkN ,
40+ uint32_t ChunkKTest ,
4141 typename CompilerTarget> // TODO: c++20 amdgcn_target_arch_id CompilerTarget>
42- // TODO: c++20 requires(is_gfx9_arch_id(CompilerTarget) && is_power_of_two_integer(FragKTest ))
42+ // TODO: c++20 requires(is_gfx9_arch_id(CompilerTarget) && is_power_of_two_integer(ChunkKTest ))
4343struct MfmaDefaultSelector
4444{
4545 private:
@@ -48,25 +48,25 @@ struct MfmaDefaultSelector
4848 amdgcn_mma<ADataType,
4949 BDataType,
5050 CDataType,
51- FragM ,
52- FragN ,
53- FragKTest ,
51+ ChunkM ,
52+ ChunkN ,
53+ ChunkKTest ,
5454 DefaultMfmaCtrlFlags, // By default, let's assume no special flags for MFMA
5555 CompilerTarget,
5656 MmaOpFamily::DENSE>;
5757
5858 public:
5959 // If the candidate is supported (e.g., a backend implementation exists), then select it.
60- // Otherwise, test another smaller FragK . If no existing implementations, we will get FragK =0u
60+ // Otherwise, test the next smaller ChunkK (halved). If no implementation exists, the search ends
6161 // at the ChunkKTest = 1u specialization, which is the unsupported pass-through implementation.
6262 using SelectedOp = std::conditional_t <MmaOpTraits<CandidateOp>::IsSupported,
6363 CandidateOp,
6464 typename MfmaDefaultSelector<ADataType,
6565 BDataType,
6666 CDataType,
67- FragM ,
68- FragN ,
69- FragKTest / 2u ,
67+ ChunkM ,
68+ ChunkN ,
69+ ChunkKTest / 2u ,
7070 CompilerTarget>::SelectedOp>;
7171};
7272
@@ -77,25 +77,25 @@ struct MfmaDefaultSelector
7777 * @tparam ADataType Data type of matrix A
7878 * @tparam BDataType Data type of matrix B
7979 * @tparam CDataType Data type of the accumulator
80- * @tparam FragM Block M dimension size
81- * @tparam FragN Block N dimension size
80+ * @tparam ChunkM Chunk M dimension size
81+ * @tparam ChunkN Chunk N dimension size
8282 * @tparam CompilerTarget The compiler target
8383 */
8484template <typename ADataType,
8585 typename BDataType,
8686 typename CDataType,
87- uint32_t FragM ,
88- uint32_t FragN ,
87+ uint32_t ChunkM ,
88+ uint32_t ChunkN ,
8989 typename CompilerTarget> // TODO: c++20 amdgcn_target_arch_id CompilerTarget>
90- struct MfmaDefaultSelector <ADataType, BDataType, CDataType, FragM, FragN , 1u , CompilerTarget>
90+ struct MfmaDefaultSelector <ADataType, BDataType, CDataType, ChunkM, ChunkN , 1u , CompilerTarget>
9191{
9292 // Default unsupported pass-through if no instruction is found
9393 using SelectedOp =
9494 amdgcn_mma<ADataType,
9595 BDataType,
9696 CDataType,
97- FragM ,
98- FragN ,
97+ ChunkM ,
98+ ChunkN ,
9999 1u ,
100100 DefaultMfmaCtrlFlags, // By default, let's assume no special flags for MFMA
101101 CompilerTarget,
@@ -105,32 +105,32 @@ struct MfmaDefaultSelector<ADataType, BDataType, CDataType, FragM, FragN, 1u, Co
105105/* *
106106 * @struct MmaDefaultSelector
107107 * @brief Implements the gfx9 default MMA selector strategy for wave-wise MMA decomposition.
108- * This implements the M/N block size search strategy to find the largest supported MFMA
108+ * This implements the M/N chunk size search strategy to find the largest supported MFMA
109109 * instruction for the given datatypes.
110110 * If no supported instruction is found, falls back to an unsupported pass-through implementation.
111111 * @tparam ADataType Data type of matrix A
112112 * @tparam BDataType Data type of matrix B
113113 * @tparam CDataType Data type of the accumulator
114- * @tparam FragM Size of the M dimension of the fragment to decompose
115- * @tparam FragN Size of the N dimension of the fragment to decompose
116- * @tparam FragK Size of the K dimension of the fragment to decompose
114+ * @tparam ChunkM Size of the M dimension of the chunk to decompose
115+ * @tparam ChunkN Size of the N dimension of the chunk to decompose
116+ * @tparam ChunkK Size of the K dimension of the chunk to decompose
117117 * @tparam CompilerTarget The compiler target
118118 * @tparam OpFamily The MMA operation family
119119 */
120120template <typename ADataType,
121121 typename BDataType,
122122 typename CDataType,
123- uint32_t FragM ,
124- uint32_t FragN ,
125- uint32_t FragK ,
123+ uint32_t ChunkM ,
124+ uint32_t ChunkN ,
125+ uint32_t ChunkK ,
126126 typename CompilerTarget,
127127 MmaOpFamily OpFamily> // TODO: c++20 amdgcn_target_arch_id CompilerTarget>
128128struct MmaDefaultSelector <ADataType,
129129 BDataType,
130130 CDataType,
131- FragM ,
132- FragN ,
133- FragK ,
131+ ChunkM ,
132+ ChunkN ,
133+ ChunkK ,
134134 CompilerTarget,
135135 OpFamily,
136136 enable_if_all<enable_if_target_family_gfx9_t <CompilerTarget>,
@@ -162,23 +162,20 @@ struct MmaDefaultSelector<ADataType,
162162 typename MfmaDefaultSelector<ADataType, BDataType, CDataType, 1u , 1u , 1u , CompilerTarget>::
163163 SelectedOp;
164164
165- // Check if each candidate is supported for the given fragment sizes
166- // For this case, we require the fragment sizes to be multiples of the MFMA shape
167- static constexpr bool IsSupported4x4 = MmaOpTraits<CandidateOp4x4>::IsSupported &&
168- (FragM % CandidateOp4x4::kM == 0u ) &&
169- (FragN % CandidateOp4x4::kN == 0u ) &&
170- (FragK % CandidateOp4x4::kK == 0u );
171- static constexpr bool IsSupported16x16 = MmaOpTraits<CandidateOp16x16>::IsSupported &&
172- (FragM % CandidateOp16x16::kM == 0u ) &&
173- (FragN % CandidateOp16x16::kN == 0u ) &&
174- (FragK % CandidateOp16x16::kK == 0u );
175- static constexpr bool IsSupported32x32 = MmaOpTraits<CandidateOp32x32>::IsSupported &&
176- (FragM % CandidateOp32x32::kM == 0u ) &&
177- (FragN % CandidateOp32x32::kN == 0u ) &&
178- (FragK % CandidateOp32x32::kK == 0u );
165+ // Check if each candidate is supported for the given chunk sizes
166+ // For this case, we require the chunk sizes to be multiples of the MFMA shape
167+ static constexpr bool IsSupported4x4 =
168+ MmaOpTraits<CandidateOp4x4>::IsSupported && (ChunkM % CandidateOp4x4::kM == 0u ) &&
169+ (ChunkN % CandidateOp4x4::kN == 0u ) && (ChunkK % CandidateOp4x4::kK == 0u );
170+ static constexpr bool IsSupported16x16 =
171+ MmaOpTraits<CandidateOp16x16>::IsSupported && (ChunkM % CandidateOp16x16::kM == 0u ) &&
172+ (ChunkN % CandidateOp16x16::kN == 0u ) && (ChunkK % CandidateOp16x16::kK == 0u );
173+ static constexpr bool IsSupported32x32 =
174+ MmaOpTraits<CandidateOp32x32>::IsSupported && (ChunkM % CandidateOp32x32::kM == 0u ) &&
175+ (ChunkN % CandidateOp32x32::kN == 0u ) && (ChunkK % CandidateOp32x32::kK == 0u );
179176
180177 public:
181- // Select the largest supported MFMA operation for the given fragment shape
178+ // Select the largest supported MFMA operation for the given chunk shape
182179 using SelectedOp = std::conditional_t <
183180 IsSupported32x32,
184181 CandidateOp32x32,
0 commit comments