
Commit 0e46977

Rewrite SingleSubgroupLayout documentation (iree-org#22412)
Signed-off-by: Benoit Jacob <[email protected]>
1 parent b4ba1e8 commit 0e46977

File tree

1 file changed: +162 -23 lines changed


compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h

Lines changed: 162 additions & 23 deletions
@@ -20,30 +20,169 @@
 
 namespace mlir::iree_compiler::IREE::GPU {
 
-// Struct describing the detailed subgroup-level layout of a MMA intrinsic.
-// Together with element type information and subgroup size, it completes the
-// full description of the semantics of a MMA intrinsic.
-//
-// All of these fields hold two values corresponding to the two dimensions of a
-// matrix in order. That is, a SingleSubgroupLayout for MmaFragment::LHS (matrix
-// A) holds the values for the M and K dimension in indices 0 and 1 of each
-// component, respectively, the RHS fragment is K x N, and the Acc fragment (for
-// accululators) is M x N.
-//
-// Note: It is not possible to infer subgroup size from the information in this
-// struct. The product of the `thread` sizes here is often, but not always equal
-// to subgroup size. When the product of the `thread` sizes (call that product
-// `P`) is smaller than subgroup size, it must be a divisor of it, and the
-// semantics in that case are that threads within the subgroup whose thread-ids
-// differ by a multiple of `P`, are accessing the same elements.
-//
-// Example observed in RDNA3 WMMA Wave32 intrinsics:
-// If the subgroup size is 32 but the product `P` of `thread` sizes is 16, that
-// means that each element is being accessed by 2 threads (2 = 32/16), and the
-// threads accessing the same element are those whose tids are exactly 16 apart.
+//////////////////////////////////////////////////////////////////////////////
+// MMASingleSubgroupLayout
+//////////////////////////////////////////////////////////////////////////////
+//
+// Overview, terminology
+// ---------------------
+//
+// MMASingleSubgroupLayout describes the layout of one operand of one
+// subgroup-level operation such as an MMA intrinsic.
+//
+// An MMA intrinsic, running on one thread, takes a 1D vector for each operand.
+// The purpose of MMASingleSubgroupLayout is to describe the mapping
+// from thread-id and 1D-vector-index into "semantical" dimensions (see next
+// paragraph). For example, for the accumulator ("C") operand of
+// @llvm.amdgcn.mfma.i32.32x32x8i8, which has type <16 x i32>, we want to know
+// for each of those 16 elements where it belongs in the MxN tile, in terms of
+// those "semantical" dimensions M and N, as a function of (1) the index of that
+// element within the <16 x i32> vector and (2) the thread-id.
+//
+// Here, by "semantical dimensions" we mean the high-level dimensions like the
+// M, N, K of matrix multiplication, or the M, N, K, Kb of scaled-matmul, or
+// the B, M, N, K of batch matmul. Some of these dimensions occur in the operand
+// that we are concerned with, some don't. For example, a matmul Lhs operand
+// only has the M and K semantical dimensions. A scaled-matmul Lhs has the M, K,
+// Kb semantical dimensions.
+//
+// Let us call "semantical rank" the number of semantical dimensions occurring
+// in the operand that we are concerned with. That is often 2, but can be 3
+// for some scaled-matmul operands and for all batch-matmul operands. It could
+// also be 1 for vector operands of matrix-vector operations.
+//
+// General invariants
+// ------------------
+//
+// Before we enter the more detailed description below, let us already state a
+// few high-level invariants.
+//
+// 0. All the member arrays have length equal to the semantical rank. A common
+//    enumeration order of semantical dimensions is used throughout.
+//    Each array entry corresponds to one semantical dimension. For example, for
+//    an MFMA Lhs operand, the semantical dims are enumerated as M, K. Thus the
+//    array elements outer[0], thread[0], element[0] correspond to the M
+//    dimension, and the [1] entries correspond to the K dimension.
+// 1. For each semantical dimension, the product (outer[i] * thread[i] *
+//    element[i]) equals the semantical dimension size, i.e., the tile size. For
+//    example, in @llvm.amdgcn.mfma.i32.32x32x8i8, for the M and N dimensions,
+//    these products equal 32.
+// 2. The product of all the outer[i] times all the element[i] equals the
+//    length of the vector operand to the intrinsic. It is the number of
+//    elements that one intrinsic consumes on one thread.
+// 3. The product of all the thread[i] is a divisor of the subgroup size. It is
+//    almost always equal to the subgroup size. If not, then it is a strict
+//    divisor, which means that multiple threads get the exact same data, i.e.,
+//    there is an implied broadcasting, as will be seen in the modulo
+//    (t % thread[0]) below.
+//
+// Detailed semantics: case of semantical rank 1
+// ---------------------------------------------
+//
+// When the semantical rank is 1, meaning that there is only 1 semantical
+// dimension (this would happen for a vector operand in a matrix-vector
+// multiplication), say "M", the mapping is as follows:
+//
+//   /* Here t == thread_id, i == vector_element_index */
+//   int map_vector_elem_index_to_semantic_dim_index(int t, int i) {
+//     return (i % element[0]) + element[0] * (
+//       (t % thread[0]) + thread[0] * (
+//         i / element[0]
+//       )
+//     );
+//   }
+//
+// Notice that we didn't use outer[0]. It is a redundant parameter in this case
+// since element[0] * outer[0] has to be the vector length, as noted in the
+// above "invariants".
+//
+// Also notice that in the (rare) case where thread[0] is smaller than the
+// subgroup size, multiple threads will get the same value of (t % thread[0])
+// and thus will get the same data.
+//
+// Detailed semantics: general case
+// --------------------------------
+//
+// The general procedure is a generalization of the above rank-1 case:
+// 1. Delinearize the vector index modulo (element[0] * ... * element[rank - 1])
+//    into the grid of shape {element[0], ..., element[rank - 1]}.
+//    Thus, the element[] array describes the tile that is stored in contiguous
+//    elements of the intrinsic's vector operand. We will call it the
+//    "element tile". The layout within the element tile is always "row-major"
+//    in the sense that the last-enumerated semantical dimension is the
+//    most-contiguous dimension.
+// 2. Delinearize the thread-id modulo (thread[0] * ... * thread[rank - 1])
+//    into the grid of shape {thread[0], ..., thread[rank - 1]}.
+//    Thus, the different threads get different element tiles (from step 1),
+//    except in the rare case that (thread[0] * ... * thread[rank - 1]) is less
+//    than the subgroup size, in which case the threads wrap around and share
+//    tiles.
+//    * Unlike the element tiles from step 1, the distribution of these tiles
+//      to threads does not necessarily follow row-major order. The thread
+//      layout is described by the tstrides. The meaning of tstrides[i] is: "as
+//      we move by one element tile along semantical dimension i, we add
+//      tstrides[i] to the thread_id". Note that in the rare case that multiple
+//      threads see the same element tile, we can have tstrides[i] == 0.
+// 3. Delinearize the "outer vector index", defined as the quotient
+//    vector_index / (element[0] * ... * element[rank - 1]),
+//    into the grid of shape {outer[0], ..., outer[rank - 1]}.
+//    Just like the element tiles, the layout of these "outer tiles" is
+//    row-major in the sense that the last-enumerated semantical dimension is
+//    the most-contiguous dimension. The outer dimensions describe the
+//    arrangement of element tiles in the overall vector operand of the
+//    intrinsic. This is used only by a minority of intrinsics to describe
+//    "non-contiguous" operand tiles.
+//
+// Example: @llvm.amdgcn.mfma.i32.32x32x8i8 accumulator
+// ----------------------------------------------------
+//
+// For the accumulator operand of @llvm.amdgcn.mfma.i32.32x32x8i8, we have:
+//
+//   outer    = {4, 1}
+//   thread   = {2, 32}
+//   tstrides = {32, 1}
+//   element  = {4, 1}
+//
+// Let us first check the general invariants:
+// 0. The arrays all have length 2, corresponding to the semantical rank 2 of
+//    matmul accumulators, where the 2 semantical dimensions are M and N.
+// 1. The semantical tile size is (M = 32, N = 32) and that does match the
+//    product of the corresponding array elements outer[i] * thread[i] *
+//    element[i], for instance 4 * 2 * 4 == 32.
+// 2. The product of all the outer[i] times all the element[i] is:
+//      4 * 1 * 4 * 1 == 16.
+//    This corresponds to the @llvm.amdgcn.mfma.i32.32x32x8i8 intrinsic's
+//    vector operand type, <16 x i32>.
+// 3. The product of the thread[i] is 2 * 32 == 64. This corresponds to the
+//    subgroup size 64 on CDNA3. It could also have been a divisor of that, but
+//    here the exact match means that each thread receives a different tile.
+//
+// The tstrides here are such that the distribution of element tiles to threads
+// is "row-major": the fact that tstrides[1] == 1 means that as we move by one
+// element tile along the N-dimension, we move to the next thread-id.
+//
+// Let us now detail the exact layout:
+// 1. Always start by looking at element[]. Here we see that the element tile
+//    has shape 4x1.
+// 2. The thread-grid has shape 2x32, and as noted above has row-major thread
+//    distribution. As what is being distributed to threads here is element
+//    tiles of shape 4x1, at this point we have distributed an overall 8x32
+//    tile.
+// 3. Finally, the outer[] completes the picture by saying that those 8x32
+//    tiles are stacked vertically to form the overall 32x32 tile. The fact
+//    that the vector length 16 is split into outer[0] == 4 and
+//    element[0] == 4 means that these vectors contain groups of 4 matrix
+//    elements that are contiguous along the M-dimension (the "element
+//    tiles"), but that there is a discontinuity as the next 4 elements (the
+//    next "element tile") come from far away elsewhere in the C matrix, owing
+//    to the fact that thread[0] is greater than 1. Example: thread 0 gets
+//    this accumulator ("C") operand:
+//      { C[0, 0],  C[1, 0],  C[2, 0],  C[3, 0],
+//        C[8, 0],  C[9, 0],  C[10, 0], C[11, 0],
+//        C[16, 0], C[17, 0], C[18, 0], C[19, 0],
+//        C[24, 0], C[25, 0], C[26, 0], C[27, 0] }
+//
 struct MMASingleSubgroupLayout {
   // Internal dimensions (as in TileSwizzle::Dim::Kind::Internal) that are
-  // outer-most in the layout. This happens when a MMA op, seen on a single
+  // outer-most in the layout. This happens when an MMA op, seen on a single
   // thread, has an operand that consists of multiple elements, and these elems
   // are NOT contiguous.
   // This is not used by every MMA op; ops which don't use that simply have 1's.
@@ -56,7 +195,7 @@ struct MMASingleSubgroupLayout {
   // Strides corresponding to the cross-thread dimensions.
   SmallVector<int64_t, 2> tstrides;
   // Internal dimensions (as in TileSwizzle::Dim::Kind::Internal) that are
-  // inner-most in the layout. This happens when a MMA op, seen on a single
+  // inner-most in the layout. This happens when an MMA op, seen on a single
   // thread, has an operand that consists of multiple elements, and these elems
   // are contiguous.
   // This is not used by every MMA op; ops which don't use that simply have 1's.
