1414
1515namespace ck_tile ::core::arch::mma {
1616
17- // TODO: Describe layout params.
17+ /* *---------------------------------------------------
18+ * Meaning of amdgcn_mma layout parameters (general)
19+ * ---------------------------------------------------
20+ *
21+ * The fragment sizes and layout constants in the amdgcn_mma struct describe the mapping between
22+ * intrinsic input / output matrix elements and vector registers (lane x vector_item space). Note
23+ * that we end up having a mapping for A, B and C separately, although those for A and B are usually
24+ * similar if not identical. All mappings can be described as an unmerge operation on one of the
25+ * matrix dims (either K for AB or M for C), followed by remerging of the resulting subdims and raw
26+ * other dim into the Lane and Vector Item dimensions. When I consider an unmerge operation on a
27+ * dimension K, I like to label the resulting sub-dimensions as K0, K1, and K2, where K0 is the size
28+ * of the fastest changing dimension. K0 is also referred to as "The size of the first unmerge", and
29+ * K1 would be "The size of the second unmerge". There are never more than 2 unmerge operations, and
30+ * unmerge operations may be trivial (unmerge size of 1). Example double unmerge of size {3, 2} of a
31+ * K dimension of size 12:
32+ *
33+ * K K2 K1 K0
34+ * 0 0 0 0
35+ * 1 0 0 1
36+ * 2 0 1 0
37+ * 3 0 1 1
38+ * 4 0 2 0
39+ * 5 0 2 1
40+ * 6 1 0 0
41+ * 7 1 0 1
42+ * 8 1 1 0
43+ * 9 1 1 1
44+ * 10 1 2 0
45+ * 11 1 2 1
46+ *
47+ * Note that K0 = 2 (first unmerge size, fastest changing), K1 = 3 (second unmerge size,
48+ * second-fastest changing), and K2 = 12 / 2 / 3 = 2 (outermost dimension, whatever is left).
49+ *
50+ * If we were to use this unmerge op to describe an A matrix layout in registers, we might have for
51+ * example that L (lane dim) is composed of K1 and M, and V (vector item dim) is composed of K2 and
52+ * K0. Compactly described, this would be K{3, 2} L{K1M} V{K2K0}, and if the M dimension was 2 we
53+ * would have the following layout (6 lanes, 4 vector items each):
54+ *
55+ * | V0 | V1 | V2 | V3 |
56+ * L0 | M=0 K=0 | M=0 K=1 | M=0 K=6 | M=0 K=7 |
57+ * L1 | M=1 K=0 | M=1 K=1 | M=1 K=6 | M=1 K=7 |
58+ * L2 | M=0 K=2 | M=0 K=3 | M=0 K=8 | M=0 K=9 |
59+ * L3 | M=1 K=2 | M=1 K=3 | M=1 K=8 | M=1 K=9 |
60+ * L4 | M=0 K=4 | M=0 K=5 | M=0 K=10 | M=0 K=11 |
61+ * L5 | M=1 K=4 | M=1 K=5 | M=1 K=10 | M=1 K=11 |
62+ *
63+ * Note that all A matrix elements are now placed in a unique (lane, vector_item). In case a Repeat
64+ * dimension is used, every single matrix element is mapped to multiple (Lane, Vector_item)
65+ * locations, usually along the Lane dimension.
66+ *
67+ * Check out TileDistrEncRegMap which can print full forward and backward mapping tables for any
68+ * register mapping (expressed as a tile distribution encoding).
69+ *
70+ * ------------------------------------------
71+ * Individual amdgcn_mma layout parameters
72+ * ------------------------------------------
73+ *
74+ * -- ABKPerLane --
75+ * The number of K dim elements in each lane. Always the same for A and B, even when they have
77+ * different layouts. In terms of unmerge sizes, it's equal to K0 * K2, i.e. the product of the sizes
77+ * of the outermost and innermost dimensions after a double K unmerge.
78+ *
79+ * -- A / B NumAccess --
80+ * These two variables describe the size of the outermost dimension if two unmerge operations are
81+ * required for K (so K2). Alternatively, it can be described as the number of sets that the vector
82+ * dimension (which houses a number of K indices) is split into. We may be able to actually
83+ * remove A and B num access as well, but it sort of depends on how load and store tile work and
84+ * whether we want the user to always have to know about this. There are only two reasons for these
85+ * to ever not be 1, and they are different types of reasons:
86+ *
87+ * (logical correctness). You are dealing with scale MFMA fp8, which due to the index matrix layout
88+ * does not allow arbitrary K perms to simplify layouts. This means the layout can only properly be
89+ * described with a Num Access value of at least 2.
90+ *
91+ * (load / store manipulation). I think the load and store tile functions end up looking for the
92+ * size of the smallest unmerged K dimension (K0) to determine how many elements should be loaded at
93+ * a time. Different Num Access values will lead to different load / store behavior, even if
94+ * logically equivalent.
95+ *
96+ * -- A / B Repeat --
97+ * Variable indicating that all matrix values are represented multiple times in the vector
98+ * registers, typically repeating in the lane dimension. This is always equal to the repeat value
99+ * used in Tile Distribution encodings. There are two reasons to have non-trivial (non-1) value
100+ * here: MFMA block-hiding to create oblong "virtual" intrinsics, and RDNA3 input repetition.
101+ *
102+ * -- CMPerLane --
103+ * The number of M dim elements in each lane. In terms of unmerge sizes, it is equal to M0 * M2, i.e.
104+ * the product of the sizes of the outermost and innermost dimensions after a double M unmerge.
105+ *
106+ * -- CNumAccess --
107+ * Same as A / B NumAccess but for the M dim (so M2), but the mid-level code doesn't care about this
108+ * and will not try to request a specific value. Absolutely needed for logical correctness of
109+ * register mappings since we can not perform arbitrary M permutations without messing up the A
110+ * layout.
111+ */
112+
18113/* *
19114 * @class amdgcn_mma_base
20115 * @brief Base class for amdgcn_mma structs to avoid a lot of code duplication. Also puts
@@ -47,19 +142,19 @@ struct amdgcn_mma_base
47142 using BDataType = BDataType_;
48143 using CDataType = CDataType_;
49144
50- // Fragment sizes
51- static constexpr index_t kM = FragM;
145+ // Fragment sizes, check description above.
146+ static constexpr index_t kM = FragM; // M = M2 * M1 * M0
52147 static constexpr index_t kN = FragN;
53- static constexpr index_t kK = FragK;
148+ static constexpr index_t kK = FragK; // K = K2 * K1 * K0
54149
55- // Layout constants
56- static constexpr index_t kABKPerLane = kABKPerLane_ ;
57- static constexpr index_t kAKNumAccess = kAKNumAccess_ ;
58- static constexpr index_t kARepeat = kARepeat_ ;
59- static constexpr index_t kBKNumAccess = kBKNumAccess_ ;
60- static constexpr index_t kBRepeat = kBRepeat_ ;
61- static constexpr index_t kCMPerLane = kCMPerLane_ ;
62- static constexpr index_t kCMNumAccess = kCMNumAccess_ ;
150+ // Layout constants, check description above.
151+ static constexpr index_t kABKPerLane = kABKPerLane_ ; // K2 * K0
152+ static constexpr index_t kAKNumAccess = kAKNumAccess_ ; // K2
153+ static constexpr index_t kARepeat = kARepeat_ ; // RDNA3 repetition and MFMA block-hiding
154+ static constexpr index_t kBKNumAccess = kBKNumAccess_ ; // K2
155+ static constexpr index_t kBRepeat = kBRepeat_ ; // RDNA3 repetition and MFMA block-hiding
156+ static constexpr index_t kCMPerLane = kCMPerLane_ ; // M2 * M0
157+ static constexpr index_t kCMNumAccess = kCMNumAccess_ ; // M2
63158
64159 // Register types (derived)
65160 static constexpr index_t WaveSize = WaveSize_;
@@ -132,6 +227,7 @@ concept MmaOpI = requires(MmaOp op) {
132227 * @tparam FragK K-dimension of mma intrinsic
133228 * @tparam CtrlFlags Control flags for mma operation
134229 * @tparam CompilerTarget The current compiler target
230+ * @tparam OpFamily_ The type of operation (dense, sparse, scale, etc.)
135231 * @tparam Enabler SFINAE enabler
136232 */
137233template <typename ADataType,
@@ -145,6 +241,7 @@ template <typename ADataType,
145241 MmaOpFamily OpFamily_,
146242 typename Enabler = void >
147243// clang-format off
244+ // | A B C DataTypes |MNK + WaveSize |AParams |BPar |CPar |
148245struct amdgcn_mma : amdgcn_mma_base<fp32_t , fp32_t , fp32_t , 1u , 1u , 1u , 1u , 1 , 1 , 1 , 1 , 1 , 1 , 1 , Unsupported, MmaOpFamily::UNDEFINED>
149246// clang-format on
150247{
0 commit comments