Skip to content

Commit b37eacf

Browse files
committed
Simplify MmaOpTraits now that almost everything is directly available from MmaOp.
1 parent 077d4b2 commit b37eacf

File tree

10 files changed

+121
-228
lines changed

10 files changed

+121
-228
lines changed

projects/composablekernel/include/ck_tile/core/arch/mma/mfma/mfma_selector.hpp

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -54,13 +54,12 @@ struct MfmaDefaultSelector
5454
DefaultMfmaCtrlFlags, // By default, let's assume no special flags for MFMA
5555
CompilerTarget,
5656
MmaOpFamily::DENSE>;
57-
using CandidateTraits = MmaOpTraits<CandidateOp>;
5857

5958
public:
6059
// If the candidate is supported (e.g., a backend implementation exists), then select it.
6160
// Otherwise, test another smaller FragK. If no implementation exists, we will get FragK=0u
6261
// and fall back to the unsupported pass-through implementation.
63-
using SelectedOp = std::conditional_t<CandidateTraits::IsSupported,
62+
using SelectedOp = std::conditional_t<MmaOpTraits<CandidateOp>::IsSupported,
6463
CandidateOp,
6564
typename MfmaDefaultSelector<ADataType,
6665
BDataType,
@@ -163,25 +162,20 @@ struct MmaDefaultSelector<ADataType,
163162
typename MfmaDefaultSelector<ADataType, BDataType, CDataType, 1u, 1u, 1u, CompilerTarget>::
164163
SelectedOp;
165164

166-
// Traits for each candidate
167-
using CandidateTraits4x4 = MmaOpTraits<CandidateOp4x4>;
168-
using CandidateTraits16x16 = MmaOpTraits<CandidateOp16x16>;
169-
using CandidateTraits32x32 = MmaOpTraits<CandidateOp32x32>;
170-
171165
// Check if each candidate is supported for the given fragment sizes
172166
// For this case, we require the fragment sizes to be multiples of the MFMA shape
173-
static constexpr bool IsSupported4x4 = CandidateTraits4x4::IsSupported &&
174-
(FragM % CandidateTraits4x4::FragM == 0u) &&
175-
(FragN % CandidateTraits4x4::FragN == 0u) &&
176-
(FragK % CandidateTraits4x4::FragK == 0u);
177-
static constexpr bool IsSupported16x16 = CandidateTraits16x16::IsSupported &&
178-
(FragM % CandidateTraits16x16::FragM == 0u) &&
179-
(FragN % CandidateTraits16x16::FragN == 0u) &&
180-
(FragK % CandidateTraits16x16::FragK == 0u);
181-
static constexpr bool IsSupported32x32 = CandidateTraits32x32::IsSupported &&
182-
(FragM % CandidateTraits32x32::FragM == 0u) &&
183-
(FragN % CandidateTraits32x32::FragN == 0u) &&
184-
(FragK % CandidateTraits32x32::FragK == 0u);
167+
static constexpr bool IsSupported4x4 = MmaOpTraits<CandidateOp4x4>::IsSupported &&
168+
(FragM % CandidateOp4x4::kM == 0u) &&
169+
(FragN % CandidateOp4x4::kN == 0u) &&
170+
(FragK % CandidateOp4x4::kK == 0u);
171+
static constexpr bool IsSupported16x16 = MmaOpTraits<CandidateOp16x16>::IsSupported &&
172+
(FragM % CandidateOp16x16::kM == 0u) &&
173+
(FragN % CandidateOp16x16::kN == 0u) &&
174+
(FragK % CandidateOp16x16::kK == 0u);
175+
static constexpr bool IsSupported32x32 = MmaOpTraits<CandidateOp32x32>::IsSupported &&
176+
(FragM % CandidateOp32x32::kM == 0u) &&
177+
(FragN % CandidateOp32x32::kN == 0u) &&
178+
(FragK % CandidateOp32x32::kK == 0u);
185179

186180
public:
187181
// Select the largest supported MFMA operation for the given fragment shape

projects/composablekernel/include/ck_tile/core/arch/mma/mma.hpp

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
#include "amdgcn_mma.hpp"
88
#include "mma_selector.hpp"
9-
#include "mma_traits.hpp"
109
#include "mma_transforms.hpp"
1110

1211
#include "mfma/mfma.hpp"
@@ -76,14 +75,12 @@ template <typename ADataType,
7675
typename MmaTransformsDefaultSelector<MmaOp, CompilerTarget>::SelectedTransforms>
7776
struct WaveWiseMma
7877
{
79-
80-
using BlockWiseMmaOp = MmaOp;
81-
using BlockWiseMmaOpTraits = MmaOpTraits<BlockWiseMmaOp>;
78+
using BlockWiseMmaOp = MmaOp;
8279

8380
// Block dimensions
84-
constexpr static uint32_t FragM = BlockWiseMmaOpTraits::FragM;
85-
constexpr static uint32_t FragN = BlockWiseMmaOpTraits::FragN;
86-
constexpr static uint32_t FragK = BlockWiseMmaOpTraits::FragK;
81+
constexpr static uint32_t FragM = MmaOp::kM;
82+
constexpr static uint32_t FragN = MmaOp::kN;
83+
constexpr static uint32_t FragK = MmaOp::kK;
8784

8885
// Block counts for decomposition
8986
constexpr static uint32_t BlocksM = ChunkM / FragM;
@@ -92,9 +89,9 @@ struct WaveWiseMma
9289
constexpr static uint32_t BlocksC = BlocksM * BlocksN;
9390

9491
// Vector types for packed registers in each block
95-
using AVecType = typename BlockWiseMmaOpTraits::AVecType;
96-
using BVecType = typename BlockWiseMmaOpTraits::BVecType;
97-
using CVecType = typename BlockWiseMmaOpTraits::CVecType;
92+
using AVecType = typename MmaOp::AVecType;
93+
using BVecType = typename MmaOp::BVecType;
94+
using CVecType = typename MmaOp::CVecType;
9895

9996
// Buffer types for chunks
10097
using ABufferType = AVecType[BlocksM][BlocksK];

projects/composablekernel/include/ck_tile/core/arch/mma/mma_traits.hpp

Lines changed: 22 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -43,45 +43,12 @@ struct is_mma_op_supported<MmaOp,
4343
template <typename MmaOp>
4444
static constexpr bool is_mma_op_supported_v = is_mma_op_supported<MmaOp>::value;
4545

46-
/**
47-
* @class MmaOpParams
48-
* @brief Reflects the template parameters of a given MmaOp
49-
* @tparam MmaOp The matrix multiply-accumulate operation type to check
50-
*/
51-
// TODO: c++20 template <MmaOpI MmaOp>
52-
template <typename MmaOp>
53-
struct MmaOpParams;
54-
55-
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
56-
#include <concepts>
57-
58-
// TODO: update concept with all params.
59-
/**
60-
* @concept MmaOpParamsI
61-
* @brief Expresses the required members for each MmaOp
62-
*/
63-
template <typename MmaOpParams>
64-
concept MmaOpParamsI = requires(MmaOpParams op) {
65-
// Capture template parameters
66-
typename MmaOpParams::ADataType;
67-
typename MmaOpParams::BDataType;
68-
typename MmaOpParams::CDataType;
69-
typename MmaOpParams::CtrlFlags;
70-
71-
{ MmaOpParams::FragM } -> std::convertible_to<unsigned int>;
72-
{ MmaOpParams::FragN } -> std::convertible_to<unsigned int>;
73-
{ MmaOpParams::FragK } -> std::convertible_to<unsigned int>;
74-
{ MmaOpParams::GfxTargetId } -> std::convertible_to<amdgcn_target_arch_id>;
75-
{ MmaOpParams::Family } -> std::convertible_to<MmaOpFamily>;
76-
};
46+
template <typename T>
47+
struct MmaOpTraits;
7748

78-
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
79-
80-
// TODO: Figure out a way to deal with all the repetition in amdgcn structs and Params / Traits
81-
// structs
8249
/**
83-
* @struct MmaOpParams
84-
* @brief Reflects the template parameters of a given MmaOp
50+
* @struct MmaOpTraits
51+
* @brief Gives additional traits and unexposed template parameters of a given MmaOp
8552
* @tparam ADataType_ Data type of matrix A
8653
* @tparam BDataType_ Data type of matrix B
8754
* @tparam CDataType_ Data type of the accumulator
@@ -101,7 +68,7 @@ template <typename ADataType_,
10168
typename CompilerTarget_,
10269
MmaOpFamily OpFamily_>
10370
// TODO: c++20 amdgcn_target_arch_id CompilerTarget_>
104-
struct MmaOpParams<amdgcn_mma<ADataType_,
71+
struct MmaOpTraits<amdgcn_mma<ADataType_,
10572
BDataType_,
10673
CDataType_,
10774
FragM_,
@@ -111,60 +78,29 @@ struct MmaOpParams<amdgcn_mma<ADataType_,
11178
CompilerTarget_,
11279
OpFamily_>>
11380
{
114-
// Capture incoming template parameters
115-
using ADataType = ADataType_;
116-
using BDataType = BDataType_;
117-
using CDataType = CDataType_;
118-
static constexpr uint32_t FragM = FragM_;
119-
static constexpr uint32_t FragN = FragN_;
120-
static constexpr uint32_t FragK = FragK_;
121-
using CtrlFlags = CtrlFlags_;
122-
using CompilerTarget = CompilerTarget_;
123-
static constexpr auto MmaOpFamily = OpFamily_;
81+
using MmaOp = amdgcn_mma<ADataType_,
82+
BDataType_,
83+
CDataType_,
84+
FragM_,
85+
FragN_,
86+
FragK_,
87+
CtrlFlags_,
88+
CompilerTarget_,
89+
OpFamily_>;
90+
91+
// Capture incoming template parameters not already exposed by amdgcn_mma
92+
using CtrlFlags = CtrlFlags_;
93+
using CompilerTarget = CompilerTarget_;
12494
// TODO: c++20 static constexpr amdgcn_target_arch_id GfxTargetId = CompilerTarget_;
125-
};
126-
127-
/**
128-
* @class MmaOpTraits
129-
* @brief Reflects the template parameters and static members of a given MmaOp.
130-
* @tparam MmaOp The matrix multiply-accumulate operation
131-
*/
132-
template <typename MmaOp>
133-
// TODO: c++20 template <MmaOpI MmaOp>
134-
// TODO: c++20 requires MmaOpParamsI<MmaOpParams<MmaOp>>
135-
struct MmaOpTraits : public MmaOpParams<MmaOp>
136-
{
137-
// Capture internal MmaOp static members
138-
using OpType = typename MmaOp::OpType;
139-
static constexpr MmaOpFamily OpFamily = MmaOp::OpFamily;
140-
141-
// Capture fragment sizes
142-
static constexpr index_t kM = MmaOp::kM;
143-
static constexpr index_t kN = MmaOp::kN;
144-
static constexpr index_t kK = MmaOp::kK;
145-
146-
// Capture layout parameters
147-
static constexpr index_t kABKPerLane = MmaOp::kABKPerLane;
148-
static constexpr index_t kAKNumAccess = MmaOp::kAKNumAccess;
149-
static constexpr index_t kARepeat = MmaOp::kARepeat;
150-
static constexpr index_t kBKNumAccess = MmaOp::kBKNumAccess;
151-
static constexpr index_t kBRepeat = MmaOp::kBRepeat;
152-
static constexpr index_t kCMPerLane = MmaOp::kCMPerLane;
153-
static constexpr index_t kCMNumAccess = MmaOp::kCMNumAccess;
154-
155-
// Capture register types
156-
using AVecType = typename MmaOp::AVecType;
157-
using BVecType = typename MmaOp::BVecType;
158-
using CVecType = typename MmaOp::CVecType;
15995

16096
// Additional traits to identify the type of MmaOp at compile time
16197
constexpr static bool IsMfma = is_mma_op_mfma_v<MmaOp>;
16298
constexpr static bool IsWmma = is_mma_op_wmma_v<MmaOp>;
163-
constexpr static bool IsDense = OpFamily == MmaOpFamily::DENSE;
164-
constexpr static bool IsSparse = OpFamily == MmaOpFamily::SPARSE;
165-
constexpr static bool IsScale = OpFamily == MmaOpFamily::SCALE;
99+
constexpr static bool IsDense = OpFamily_ == MmaOpFamily::DENSE;
100+
constexpr static bool IsSparse = OpFamily_ == MmaOpFamily::SPARSE;
101+
constexpr static bool IsScale = OpFamily_ == MmaOpFamily::SCALE;
166102
constexpr static bool IsSupported =
167-
is_mma_op_supported_v<MmaOp> && OpFamily != MmaOpFamily::UNDEFINED;
103+
is_mma_op_supported_v<MmaOp> && OpFamily_ != MmaOpFamily::UNDEFINED;
168104
};
169105

170106
} // namespace ck_tile::core::arch::mma

projects/composablekernel/include/ck_tile/core/arch/mma/sparse/mfma/selector.hpp

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,10 @@ struct SparseMfmaDefaultSelector
4444
CompilerTarget,
4545
MmaOpFamily::SPARSE>;
4646

47-
using CandidateTraits = MmaOpTraits<CandidateOp>;
48-
4947
public:
5048
// If the candidate is supported (e.g., a backend implementation exists), then select it.
5149
// Otherwise, fall back to the unsupported pass-through implementation.
52-
using SelectedOp = std::conditional_t<CandidateTraits::IsSupported,
50+
using SelectedOp = std::conditional_t<MmaOpTraits<CandidateOp>::IsSupported,
5351
CandidateOp,
5452
amdgcn_mma<ADataType,
5553
BDataType,
@@ -125,20 +123,16 @@ struct MmaDefaultSelector<ADataType,
125123
1u,
126124
CompilerTarget>::SelectedOp;
127125

128-
// Traits for each candidate
129-
using CandidateTraits16x16 = MmaOpTraits<CandidateOp16x16>;
130-
using CandidateTraits32x32 = MmaOpTraits<CandidateOp32x32>;
131-
132126
// Check if each candidate is supported for the given fragment sizes
133127
// For this case, we require the fragment sizes to be multiples of the MFMA shape
134-
static constexpr bool IsSupported16x16 = CandidateTraits16x16::IsSupported &&
135-
(FragM % CandidateTraits16x16::FragM == 0u) &&
136-
(FragN % CandidateTraits16x16::FragN == 0u) &&
137-
(FragK % CandidateTraits16x16::FragK == 0u);
138-
static constexpr bool IsSupported32x32 = CandidateTraits32x32::IsSupported &&
139-
(FragM % CandidateTraits32x32::FragM == 0u) &&
140-
(FragN % CandidateTraits32x32::FragN == 0u) &&
141-
(FragK % CandidateTraits32x32::FragK == 0u);
128+
static constexpr bool IsSupported16x16 = MmaOpTraits<CandidateOp16x16>::IsSupported &&
129+
(FragM % CandidateOp16x16::kM == 0u) &&
130+
(FragN % CandidateOp16x16::kN == 0u) &&
131+
(FragK % CandidateOp16x16::kK == 0u);
132+
static constexpr bool IsSupported32x32 = MmaOpTraits<CandidateOp32x32>::IsSupported &&
133+
(FragM % CandidateOp32x32::kM == 0u) &&
134+
(FragN % CandidateOp32x32::kN == 0u) &&
135+
(FragK % CandidateOp32x32::kK == 0u);
142136

143137
public:
144138
// Select the largest supported MFMA operation for the given fragment shape

projects/composablekernel/include/ck_tile/core/arch/mma/sparse/wmma/selector.hpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -116,15 +116,12 @@ struct MmaDefaultSelector<ADataType,
116116
1u,
117117
CompilerTarget>::SelectedOp;
118118

119-
// Traits for each candidate
120-
using CandidateTraits16x16 = MmaOpTraits<CandidateOp16x16>;
121-
122119
// Check if each candidate is supported for the given fragment sizes
123120
// For this case, we require the fragment sizes to be multiples of the WMMA shape
124-
static constexpr bool IsSupported16x16 = CandidateTraits16x16::IsSupported &&
125-
(FragM % CandidateTraits16x16::FragM == 0u) &&
126-
(FragN % CandidateTraits16x16::FragN == 0u) &&
127-
(FragK % CandidateTraits16x16::FragK == 0u);
121+
static constexpr bool IsSupported16x16 = MmaOpTraits<CandidateOp16x16>::IsSupported &&
122+
(FragM % CandidateOp16x16::kM == 0u) &&
123+
(FragN % CandidateOp16x16::kN == 0u) &&
124+
(FragK % CandidateOp16x16::kK == 0u);
128125

129126
public:
130127
// Select the largest supported WMMA operation for the given fragment shape

projects/composablekernel/include/ck_tile/core/arch/mma/wmma/wmma_selector.hpp

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,11 @@ struct WmmaDefaultSelector
4949
CompilerTarget,
5050
MmaOpFamily::DENSE>;
5151

52-
using CandidateTraits = MmaOpTraits<CandidateOp>;
53-
5452
public:
5553
// If the candidate is supported (e.g., a backend implementation exists), then select it.
5654
// Otherwise, test another smaller FragK. If no implementation exists, we will get FragK=0u
5755
// and fall back to the unsupported pass-through implementation.
58-
using SelectedOp = std::conditional_t<CandidateTraits::IsSupported,
56+
using SelectedOp = std::conditional_t<MmaOpTraits<CandidateOp>::IsSupported,
5957
CandidateOp,
6058
typename WmmaDefaultSelector<ADataType,
6159
BDataType,
@@ -155,15 +153,12 @@ struct MmaDefaultSelector<ADataType,
155153
typename WmmaDefaultSelector<ADataType, BDataType, CDataType, 1u, 1u, 1u, CompilerTarget>::
156154
SelectedOp;
157155

158-
// Traits for each candidate
159-
using CandidateTraits16x16 = MmaOpTraits<CandidateOp16x16>;
160-
161156
// Check if each candidate is supported for the given fragment sizes
162157
// For this case, we require the fragment sizes to be multiples of the WMMA shape
163-
static constexpr bool IsSupported16x16 = CandidateTraits16x16::IsSupported &&
164-
(FragM % CandidateTraits16x16::FragM == 0u) &&
165-
(FragN % CandidateTraits16x16::FragN == 0u) &&
166-
(FragK % CandidateTraits16x16::FragK == 0u);
158+
static constexpr bool IsSupported16x16 = MmaOpTraits<CandidateOp16x16>::IsSupported &&
159+
(FragM % CandidateOp16x16::kM == 0u) &&
160+
(FragN % CandidateOp16x16::kN == 0u) &&
161+
(FragK % CandidateOp16x16::kK == 0u);
167162

168163
public:
169164
// Select the largest supported WMMA operation for the given fragment shape

0 commit comments

Comments
 (0)