@@ -20,30 +20,169 @@
 
 namespace mlir::iree_compiler::IREE::GPU {
 
-// Struct describing the detailed subgroup-level layout of a MMA intrinsic.
-// Together with element type information and subgroup size, it completes the
-// full description of the semantics of a MMA intrinsic.
-//
-// All of these fields hold two values corresponding to the two dimensions of a
-// matrix in order. That is, a SingleSubgroupLayout for MmaFragment::LHS (matrix
-// A) holds the values for the M and K dimension in indices 0 and 1 of each
-// component, respectively, the RHS fragment is K x N, and the Acc fragment (for
-// accumulators) is M x N.
-//
-// Note: It is not possible to infer subgroup size from the information in this
-// struct. The product of the `thread` sizes here is often, but not always equal
-// to subgroup size. When the product of the `thread` sizes (call that product
-// `P`) is smaller than subgroup size, it must be a divisor of it, and the
-// semantics in that case are that threads within the subgroup whose thread-ids
-// differ by a multiple of `P`, are accessing the same elements.
-//
-// Example observed in RDNA3 WMMA Wave32 intrinsics:
-// If the subgroup size is 32 but the product `P` of `thread` sizes is 16, that
-// means that each element is being accessed by 2 threads (2 = 32/16), and the
-// threads accessing the same element are those whose tids are exactly 16 apart.
+// ////////////////////////////////////////////////////////////////////////////
+// MMASingleSubgroupLayout
+// ////////////////////////////////////////////////////////////////////////////
+//
+// Overview, terminology
+// ---------------------
+//
+// MMASingleSubgroupLayout describes the layout of one operand of one
+// subgroup-level operation such as an MMA intrinsic.
+//
+// An MMA intrinsic, running on one thread, takes a 1D vector for each operand.
+// The purpose of MMASingleSubgroupLayout is to describe the mapping
+// from thread-id and 1D-vector-index into "semantical" dimensions (see next
+// paragraph). For example, for the accumulator ("C") operand of
+// @llvm.amdgcn.mfma.i32.32x32x8i8, which has type <16 x i32>, we want to know
+// for each of those 16 elements where it belongs in the MxN tile, in terms of
+// those "semantical" dimensions M and N, as a function of (1) the index of that
+// element within the <16 x i32> vector and (2) the thread-id.
+//
+// Here, by "semantical dimensions" we mean the high-level dimensions like the
+// M, N, K of matrix multiplication, or the M, N, K, Kb of scaled-matmul, or
+// the B, M, N, K of batch matmul. Some of these dimensions occur in the operand
+// that we are concerned with, some don't. For example, a matmul Lhs operand
+// only has the M and K semantical dimensions. A scaled-matmul Lhs has the M, K,
+// Kb semantical dimensions.
+//
+// Let us call "semantical rank" the number of semantical dimensions occurring
+// in the operand that we are concerned with. That is often 2, but can be 3
+// for some scaled-matmul operands and for all batch-matmul operands. This could
+// also be 1 for vector operands of matrix-vector operations.
+//
+// General invariants
+// ------------------
+//
+// Before we enter the more detailed description below, let us already state a
+// few high-level invariants.
+//
+// 0. All the member arrays have length equal to the semantical rank. A common
+//    enumeration order of semantical dimensions is used throughout.
+//    Each array entry corresponds to one semantical dimension. For example, for
+//    an MFMA Lhs operand, the semantical dims are enumerated as M, K. Thus the
+//    array elements outer[0], thread[0], element[0] correspond to the M
+//    dimension, and the [1] entries correspond to the K dimension.
+// 1. For each semantical dimension, the product (outer[i] * thread[i] *
+//    element[i]) equals the semantical dimension size, i.e., the tile size. For
+//    example, in @llvm.amdgcn.mfma.i32.32x32x8i8, for the M and N dimensions,
+//    these products equal 32.
+// 2. The product of all the outer[i] times all the element[i] equals the
+//    length of the vector operand to the intrinsic. It is the number of
+//    elements that one intrinsic consumes on one thread.
+// 3. The product of all the thread[i] is a divisor of subgroup size. It is
+//    almost always equal to subgroup size. If not, then it is a strict divisor
+//    of subgroup size and that means that multiple threads get the exact same
+//    data, i.e., there is an implied broadcasting, as will be seen in the
+//    modulo (t % thread[0]) below.
+//
+// Detailed semantics: case of semantic rank 1
+// -------------------------------------------
+//
+// When the semantic rank is 1, meaning that there is only 1 semantic dimension
+// (this would happen for a vector operand in a matrix-vector multiplication),
+// say "M", the mapping is as follows:
+//
+//   /* Here t == thread_id, i == vector_element_index */
+//   int map_vector_elem_index_to_semantic_dim_index(int t, int i) {
+//     return (i % element[0]) + element[0] * (
+//       (t % thread[0]) + thread[0] * (
+//         i / element[0]
+//       )
+//     );
+//   }
+//
+// Notice that we didn't use outer[0]. It is a redundant parameter in this case
+// since element[0] * outer[0] has to be the vector length as noted in the above
+// "invariants".
+//
+// Also notice that in the (rare) case where thread[0] is smaller than subgroup
+// size, multiple threads will get the same value of (t % thread[0]) and thus
+// will get the same data.
+//
+// Detailed semantics: general case
+// --------------------------------
+//
+// The general procedure is a generalization of the above rank-1 case:
+// 1. Delinearize the vector index modulo (element[0] * ... * element[rank - 1])
+//    into the grid of shape {element[0], ..., element[rank - 1]}.
+//    Thus, the element[] array describes the tile that is stored in contiguous
+//    elements of the intrinsic's vector operand. We will call it the
+//    "element tile". The layout within the element tile is always "row-major"
+//    in the sense that the last-enumerated semantic dimension is the
+//    most-contiguous dimension.
+// 2. Delinearize the thread-id modulo (thread[0] * ... * thread[rank - 1])
+//    into the grid of shape {thread[0], ..., thread[rank - 1]}.
+//    Thus, the different threads get different element tiles (from step 1)
+//    except in the rare case that (thread[0] * ... * thread[rank - 1]) is less
+//    than subgroup size, in which case the threads wrap around and share tiles.
+//    * Unlike the element tiles from step 1, the distribution of these tiles to
+//      threads does not necessarily follow row-major order. The thread layout
+//      is described by the tstrides. The meaning of tstrides[i] is: "as we move
+//      by one element tile along semantic dimension i, we add tstrides[i]
+//      to the thread_id". Note that in the rare case that multiple threads see
+//      the same element tile, we can have tstrides[i] == 0.
+// 3. Delinearize the "outer vector index", defined as the quotient
+//    vector_index / (element[0] * ... * element[rank - 1]),
+//    into the grid of shape {outer[0], ..., outer[rank - 1]}.
+//    Just like the element-tiles, the layout of these "outer tiles" is
+//    row-major in the sense that the last-enumerated semantic dimension is the
+//    most-contiguous dimension. The outer dimensions describe the arrangement
+//    of element tiles in the overall vector operand of the intrinsic. This is
+//    used only by a minority of intrinsics to describe "non-contiguous" operand
+//    tiles.
+//
+// Example: @llvm.amdgcn.mfma.i32.32x32x8i8 accumulator
+// ----------------------------------------------------
+//
+// For the accumulator operand of @llvm.amdgcn.mfma.i32.32x32x8i8, we have:
+//
+//   outer    = {4, 1}
+//   thread   = {2, 32}
+//   tstrides = {32, 1}
+//   element  = {4, 1}
+//
+// Let us first check the general invariants:
+// 0. The arrays all have length 2, corresponding to the semantic rank 2 of
+//    matmul accumulators, where the 2 semantic dimensions are M and N.
+// 1. The semantic tile size is (M = 32, N = 32) and that does match the product
+//    of the corresponding array elements outer[i] * thread[i] * element[i],
+//    for instance 4 * 2 * 4 == 32.
+// 2. The product of all the outer[i] times all the element[i] is:
+//      4 * 1 * 4 * 1 == 16.
+//    This corresponds to the @llvm.amdgcn.mfma.i32.32x32x8i8 intrinsic's vector
+//    operand type, <16 x i32>.
+// 3. The product of the thread[i] is 2 * 32 == 64. This corresponds to the
+//    subgroup size 64 on CDNA3. It could also have been a divisor of that, but
+//    here the exact match means that each thread receives a different tile.
+//
+// The tstrides here are such that the distribution of element-tiles to threads
+// is "row-major": the fact that tstrides[1] == 1 means that as we move by
+// one element tile along the N-dimension, the thread-id increases by 1.
+//
+// Let us now detail the exact layout:
+// 1. Always start by looking at element[]. Here we see that the element tile
+//    has shape 4x1.
+// 2. The thread-grid has shape 2x32, and as noted above has row-major thread
+//    distribution. As what is being distributed to threads here is element
+//    tiles of shape 4x1, at this point we have distributed an overall 8x32
+//    tile.
+// 3. Finally, the outer[] completes the picture by saying that those 8x32
+//    tiles are stacked vertically to form the overall 32x32 tile. The fact that
+//    the vector length 16 is split into outer[0]==4 and element[0]==4 means
+//    that these vectors contain groups of 4 matrix elements that are contiguous
+//    along the M-dimension (the "element tiles"), but that there is a
+//    discontinuity as the next 4 elements (the next "element tile") comes from
+//    far away elsewhere in the C matrix, owing to the fact that thread[0] is
+//    greater than 1. Example: thread 0 gets this accumulator ("C") operand:
+//      { C[0, 0],  C[1, 0],  C[2, 0],  C[3, 0],
+//        C[8, 0],  C[9, 0],  C[10, 0], C[11, 0],
+//        C[16, 0], C[17, 0], C[18, 0], C[19, 0],
+//        C[24, 0], C[25, 0], C[26, 0], C[27, 0] }
+//
 struct MMASingleSubgroupLayout {
   // Internal dimensions (as in TileSwizzle::Dim::Kind::Internal) that are
-  // outer-most in the layout. This happens when a MMA op, seen on a single
+  // outer-most in the layout. This happens when an MMA op, seen on a single
   // thread, has an operand that consists of multiple elements, and these elems
   // are NOT contiguous.
   // This is not used by every MMA op; ops which don't use that simply have 1's.
@@ -56,7 +195,7 @@ struct MMASingleSubgroupLayout {
   // Strides corresponding to the cross-thread dimensions.
   SmallVector<int64_t, 2> tstrides;
   // Internal dimensions (as in TileSwizzle::Dim::Kind::Internal) that are
-  // inner-most in the layout. This happens when a MMA op, seen on a single
+  // inner-most in the layout. This happens when an MMA op, seen on a single
   // thread, has an operand that consists of multiple elements, and these elems
   // are contiguous.
   // This is not used by every MMA op; ops which don't use that simply have 1's.