Skip to content

Commit 0e26f69

Browse files
committed
Add detailed layout parameter descriptions.
1 parent 7a1ca1a commit 0e26f69

File tree

6 files changed

+115
-12
lines changed

6 files changed

+115
-12
lines changed

projects/composablekernel/include/ck_tile/core/arch/mma/amdgcn_mma.hpp

Lines changed: 109 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,102 @@
1414

1515
namespace ck_tile::core::arch::mma {
1616

17-
// TODO: Describe layout params.
17+
/**---------------------------------------------------
18+
* Meaning of amdgcn_mma layout parameters (general)
19+
* ---------------------------------------------------
20+
*
21+
* The fragment sizes and layout constants in the amdgcn_mma struct describe the mapping between
22+
* intrinsic input / output matrix elements and vector registers (lane x vector_item space). Note
23+
* that we end up having a mapping for A, B and C separately, although those for A and B are usually
24+
* similar if not identical. All mappings can be described as an unmerge operation on one of the
25+
* matrix dims (either K for AB or M for C), followed by remerging of the resulting subdims and raw
26+
* other dim into the Lane and Vector Item dimensions. When I consider an unmerge operation on a
27+
* dimension K, I like to label the resulting sub-dimensions as K0, K1, and K2, where K0 is the size
28+
* of the fastest changing dimension. K0 is also referred to as "The size of the first unmerge", and
29+
* K1 would be "The size of the second unmerge". There are never more than 2 unmerge operations, and
30+
* unmerge operations may be trivial (unmerge size of 1). Example double unmerge of size {3, 2} of a
31+
* K dimension of size 12:
32+
*
33+
* K K2 K1 K0
34+
* 0 0 0 0
35+
* 1 0 0 1
36+
* 2 0 1 0
37+
* 3 0 1 1
38+
* 4 0 2 0
39+
* 5 0 2 1
40+
* 6 1 0 0
41+
* 7 1 0 1
42+
* 8 1 1 0
43+
* 9 1 1 1
44+
* 10 1 2 0
45+
* 11 1 2 1
46+
*
47+
* Note that K0 = 2 (first unmerge size, fastest changing), K1 = 3 (second unmerge size,
48+
* second-fastest changing), and K2 = 12 / 2 / 3 = 2 (outermost dimension, whatever is left).
49+
*
50+
* If we were to use this unmerge op to decribe an A matrix layout in registers, we might have for
51+
* example that L (lane dim) is composed of K1 and M, and V (vector item dim) is composed of K2 and
52+
* K0. Compactly described, this would be K{3, 2} L{K1M} V{K2K0}, and if the M dimension was 2 we
53+
* would have the following layout (6 lanes, 4 vector items each):
54+
*
55+
* | V0 | V1 | V2 | V3 |
56+
* L0 | M=0 K=0 | M=0 K=1 | M=0 K=6 | M=0 K=7 |
57+
* L1 | M=1 K=0 | M=1 K=1 | M=1 K=6 | M=1 K=7 |
58+
* L2 | M=0 K=2 | M=0 K=3 | M=0 K=8 | M=0 K=9 |
59+
* L3 | M=1 K=2 | M=1 K=3 | M=1 K=8 | M=1 K=9 |
60+
* L4 | M=0 K=4 | M=0 K=5 | M=0 K=10 | M=0 K=11 |
61+
* L5 | M=1 K=4 | M=1 K=5 | M=1 K=10 | M=1 K=11 |
62+
*
63+
* Note that all A matrix elements are now placed in a unique (lane, vector_item). In case a Repeat
64+
* dimension is used, every single matrix element is mapped to multiple (Lane, Vector_item)
65+
* locations, usually along the Lane dimension.
66+
*
67+
* Check out TileDistrEncRegMap which can print full forward and backward mapping tables for any
68+
* register mapping (expressed as a tile distribution encoding).
69+
*
70+
* ------------------------------------------
71+
* Individual amdgcn_mma layout parameters
72+
* ------------------------------------------
73+
*
74+
* -- ABKPerLane --
75+
* The number of K dim elements in each lane. Always the same for A and B, even when they have
76+
* different layouts. In terms of unmerge sizes, it's equal to K0 * K2, i.e the product of the sizes
77+
* of the outermost and innermost dimensions after a double K unmerge.
78+
*
79+
* -- A / B NumAccess --
80+
* These two variables describe the size of the outermost dimension if two unmerge operations are
81+
* required for K (so K2). Alternatively it can be described as the number of sets the vector
82+
* dimension, which houses a number of K indices, is split up into. We may be able to actually
83+
* remove A and B num access as well, but it sort of depends on how load and store tile work and
84+
* whether we want the user to always have to know about this. There are only two reasons for these
85+
* to ever not be 1, and they are different types of reasons:
86+
*
87+
* (logical correctness). You are dealing with scale MFMA fp8, which due to the index matrix layout
88+
* does not allow arbitrary K perms to simplify layouts. This means the layout can only properly be
89+
* described with a Num Access value of at least 2.
90+
*
91+
* (load / store manipulation). I think the load and store tile functions end up looking for the
92+
* size of the smallest unmerged K dimension (K0) to determine how many elements should be loaded at
93+
* a time. Different Num Access values will lead to different load / store behavior, even if
94+
* logically equivalent.
95+
*
96+
* -- A / B Repeat --
97+
* Variable indicating that all matrix values are represented multiple times in the vector
98+
* reigsters, typically repeating in the lane dimension. This is always equal to the repeat value
99+
* used in Tile Distribution encodings. There are two reasons to have non-trivial (non-1) value
100+
* here: MFMA block-hiding to create oblong "virtual" intrinsics, and RDNA3 input repetition.
101+
*
102+
* -- CMPerLane --
103+
* The number of M dim elements in each lane. In terms of unmerge sizes, is equal to M0 * M2, i.e
104+
* the product of the sizes of the outermost and innermost dimensions after a double M unmerge.
105+
*
106+
* -- CNumAccess --
107+
* Same as A / B NumAccess but for the M dim (so M2), but the mid-level code doesn't care about this
108+
* and will not try to request a specific value. Absolutely needed for logical correctness of
109+
* register mappings since we can not perform arbitrary M permutations without messing up the A
110+
* layout.
111+
*/
112+
18113
/**
19114
* @class amdgcn_mma_base
20115
* @brief Base class for amdgcn_mma structs to avoid a lot of code duplication. Also puts
@@ -47,19 +142,19 @@ struct amdgcn_mma_base
47142
using BDataType = BDataType_;
48143
using CDataType = CDataType_;
49144

50-
// Fragment sizes
51-
static constexpr index_t kM = FragM;
145+
// Fragment sizes, check description above.
146+
static constexpr index_t kM = FragM; // M = M2 * M1 * M0
52147
static constexpr index_t kN = FragN;
53-
static constexpr index_t kK = FragK;
148+
static constexpr index_t kK = FragK; // K = K2 * K1 * K0
54149

55-
// Layout constants
56-
static constexpr index_t kABKPerLane = kABKPerLane_;
57-
static constexpr index_t kAKNumAccess = kAKNumAccess_;
58-
static constexpr index_t kARepeat = kARepeat_;
59-
static constexpr index_t kBKNumAccess = kBKNumAccess_;
60-
static constexpr index_t kBRepeat = kBRepeat_;
61-
static constexpr index_t kCMPerLane = kCMPerLane_;
62-
static constexpr index_t kCMNumAccess = kCMNumAccess_;
150+
// Layout constants, check description above.
151+
static constexpr index_t kABKPerLane = kABKPerLane_; // K2 * K0
152+
static constexpr index_t kAKNumAccess = kAKNumAccess_; // K2
153+
static constexpr index_t kARepeat = kARepeat_; // RDNA3 repetition and MFMA block-hiding
154+
static constexpr index_t kBKNumAccess = kBKNumAccess_; // K2
155+
static constexpr index_t kBRepeat = kBRepeat_; // RDNA3 repetition and MFMA block-hiding
156+
static constexpr index_t kCMPerLane = kCMPerLane_; // M2 * M0
157+
static constexpr index_t kCMNumAccess = kCMNumAccess_; // M2
63158

64159
// Register types (derived)
65160
static constexpr index_t WaveSize = WaveSize_;
@@ -132,6 +227,7 @@ concept MmaOpI = requires(MmaOp op) {
132227
* @tparam FragK K-dimension of mma intrinsic
133228
* @tparam CtrlFlags Control flags for mma operation
134229
* @tparam CompilerTarget The current compiler target
230+
* @tparam OpFamily_ The type of operation (dense, sparse, scale, etc.)
135231
* @tparam Enabler SFINAE enabler
136232
*/
137233
template <typename ADataType,
@@ -145,6 +241,7 @@ template <typename ADataType,
145241
MmaOpFamily OpFamily_,
146242
typename Enabler = void>
147243
// clang-format off
244+
// | A B C DataTypes |MNK + WaveSize |AParams |BPar |CPar |
148245
struct amdgcn_mma : amdgcn_mma_base<fp32_t, fp32_t, fp32_t, 1u, 1u, 1u, 1u, 1, 1, 1, 1, 1, 1, 1, Unsupported, MmaOpFamily::UNDEFINED>
149246
// clang-format on
150247
{

projects/composablekernel/include/ck_tile/core/arch/mma/mfma/mfma_gfx9.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ concept CtrlFlagsGfx9I = requires(CtrlFlags ctrlFlags) {
6262
// TODO: c++20 requires
6363
template <typename CtrlFlags, typename CompilerTarget>
6464
// clang-format off
65+
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
6566
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx9_t<CompilerTarget>>
6667
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, 64u, 4, 1, 1, 1, 1, 4, 1, MfmaOp, MmaOpFamily::DENSE>
6768
// clang-format on
@@ -92,6 +93,7 @@ struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarg
9293
// TODO: c++20 requires
9394
template <typename CtrlFlags, typename CompilerTarget>
9495
// clang-format off
96+
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
9597
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
9698
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u, 64u, 8, 1, 1, 1, 1, 4, 1, MfmaOp, MmaOpFamily::DENSE>
9799
// clang-format on

projects/composablekernel/include/ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ concept SparseMfmaCtrlFlags = requires(CtrlFlags ctrlFlags) {
4949
// TODO: c++20 requires
5050
template <typename CtrlFlags, typename CompilerTarget>
5151
// clang-format off
52+
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
5253
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, std::enable_if_t<is_any_value_of(CompilerTarget::TARGET_ID, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950)>>
5354
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u, 64u, 8, 1, 1, 1, 1, 4, 1, MfmaOp, MmaOpFamily::SPARSE>
5455
// clang-format on

projects/composablekernel/include/ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ struct DefaultWmmaCtrlFlags
7171
// TODO: c++20 requires
7272
template <typename CtrlFlags, typename CompilerTarget>
7373
// clang-format off
74+
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
7475
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, std::enable_if_t<is_target_family_gfx11<CompilerTarget>()>>
7576
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, 32u, 16, 1, 2, 1, 2, 8, 8, WmmaOp, MmaOpFamily::DENSE>
7677
// clang-format on

projects/composablekernel/include/ck_tile/core/arch/mma/wmma/wmma_gfx12.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ namespace ck_tile::core::arch::mma {
3131
// TODO: c++20 requires
3232
template <typename CtrlFlags, typename CompilerTarget>
3333
// clang-format off
34+
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
3435
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, CtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_family_gfx12_t<CompilerTarget>>
3536
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 16u, 16u, 16u, 32u, 8, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::DENSE>
3637
// clang-format on

projects/composablekernel/test/ck_tile/core/arch/mma/test_amdgcn_mma.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ using enable_if_target_id_dummy_t = std::enable_if_t<is_dummy_target(CompilerTar
4242
// TODO: c++20 template <amdgcn_target_arch_id CompilerTarget>
4343
template <typename CompilerTarget>
4444
// clang-format off
45+
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
4546
struct amdgcn_mma<fp32_t, fp32_t, fp32_t, 8u, 8u, 8u, DummyCtrlFlags, CompilerTarget, MmaOpFamily::DENSE, enable_if_target_id_dummy_t<CompilerTarget>>
4647
: amdgcn_mma_base<fp32_t, fp32_t, fp32_t, 8u, 8u, 8u, 64u, 1, 1, 1, 1, 1, 1, 1, DummyOpType, MmaOpFamily::DENSE>
4748
// clang-format on

0 commit comments

Comments
 (0)