@@ -99,174 +99,6 @@ LinearLayout chooseShemLayoutForRegToRegConversion(
    MLIRContext *ctx, ArrayRef<unsigned> tensorShape,
    ArrayRef<unsigned> repShape, ArrayRef<unsigned> order);

- // This function constructs a linear layout that maps
- // <register, lane, warp> to <shared memory offset, iteration>.
- // The primary goal is to efficiently store 2D tiles of a tensor into shared
- // memory using the `stmatrix` instruction, with each thread responsible for
- // storing `N` elements. If `stmatrix` cannot be used for the given tensor
- // encoding, this function returns `std::nullopt`.
- //
- // Unlike standard vectorized stores, such as `st.shared.v4 [%offset],
- // %vec_reg`, where `%vec_reg` contains four consecutive data elements, the
- // `stmatrix` instruction allows `N` registers to point to non-contiguous
- // locations within a tensor tile.
- //
- // For instance, the `stmatrix [%offset], %mat_reg` instruction on NVIDIA GPUs
- // enables `%mat_reg` to store `N` elements that do not need to be consecutive.
- // However, the address (`%offset`) of each row in a tensor tile must be
- // aligned to `N` * `elemBitWidth`. The `%offset` of each thread is calculated
- // based on the provided tensor encoding.
- //
- // Currently, we support only the NVIDIA MMAv3 encoding and the `stmatrix.x4`
- // instruction. Each `stmatrix.x4` instruction stores eight 16-bit elements per
- // thread, resulting in a total of 8 * 32 = 256 elements per warp, or 16 * 16
- // elements per warp when distributed across four 8x8 tiles. Each thread's
- // `%offset` points to an address aligned to 8 * 16 bits, denoting a row in
- // the 8x8 tile. The values in `%mat_reg` are not all consecutive; they are
- // composed of 4 pairs of consecutive elements. These matrix addresses are
- // distributed as follows:
- //
- //              col[0-7]      col[8-15]
- //   row[0-7]   lane[0-7]     lane[16-23]
- //   row[8-15]  lane[8-15]    lane[24-31]
- //
- // The matrix elements of thread 0 are distributed in the following pattern:
- //
- //          col0       col8
- //   row0   reg[0-1]   reg[4-5]
- //   row8   reg[2-3]   reg[6-7]
- //
- // When `swizzleByteSize` is non-zero, the layout is constructed
- // differently due to leading dimension offset and swizzling.
- // There are two key concepts to understand:
- //
- // 1. Chunks: The leading dimension (i.e., the column dimension) is divided
- //    into chunks, where each chunk's size is determined by `swizzleByteSize`.
- // 2. Swizzling within tiles: Each tile applies a swizzling pattern to its
- //    rows to optimize memory access.
- //
- // - Concept 1: Chunks
- //
- // In the swizzled layout, the leading dimension is strided by
- // `swizzleByteSize`. This introduces the concept of a "chunk", where each chunk
- // spans a certain number of columns.
- //
- // For a tile size of `stmatrix.x4` (16x16 elements), with each element being 16
- // bits (2 bytes), each tile occupies 16 rows and 32 bytes per row (since 16
- // elements * 2 bytes per element = 32 bytes per row).
- //
- // Given a `swizzleByteSize` of 128 bytes, the number of tiles per chunk can be
- // calculated as:
- //
- //   Number of tiles per chunk = swizzleByteSize / (bytes per row)
- //                             = 128 bytes / 32 bytes = 4 tiles
- //
- // Therefore, each chunk contains 4 tiles horizontally, spanning 64 columns
- // (since each tile is 16 columns):
- //
- //             col0-15  col16-31  col32-47  col48-63
- //   row0-15   tile0    tile1     tile2     tile3
- //
- // For a tensor of size 128x128 elements (#rows x #columns), and each element
- // being 16 bits, the tensor can be divided into multiple chunks both
- // horizontally and vertically. Chunks are stored in memory in "column-major"
- // order, meaning chunk1's address immediately follows chunk0's.
- //
- // Assume we have 8 warps and assign each warp to process a slice of 16 rows
- // (the rows per tile) by 128 columns (the width of two chunks). Each warp then
- // handles one horizontal slice of the tensor.
- //
- // The overall layout can be visualized as:
- //
- //                        |<- 128 * 128 bytes ->|<- 128 * 128 bytes ->|
- //                             columns 0-63         columns 64-127
- //   warp0 | rows 0-15         chunk0                chunk8
- //   warp1 | rows 16-31        chunk1                chunk9
- //   warp2 | rows 32-47        chunk2                chunk10
- //   warp3 | rows 48-63        chunk3                chunk11
- //   warp4 | rows 64-79        chunk4                chunk12
- //   warp5 | rows 80-95        chunk5                chunk13
- //   warp6 | rows 96-111       chunk6                chunk14
- //   warp7 | rows 112-127      chunk7                chunk15
- //
- // - Concept 2: Swizzling within tiles
- //
- // Within each 16x16 tile, rows are swizzled to optimize memory access patterns.
- // This swizzling is similar to what's defined in `TritonGPUAttrDefs.td`, but
- // applied at the level of each 16x16 tile rather than the entire tensor.
- //
- // Key parameters for swizzling:
- //
- // - `perPhase`: The number of rows over which to apply an XOR operation at
- //   each phase.
- // - `maxPhase`: The total number of phases.
- // - `vectorWidth`: The number of elements per vector, which is 8 in this case
- //   because `stmatrix` stores 8 contiguous elements per thread.
- //
- // The offset of each element within a tile is calculated using the formula:
- //
- //   offset = row * swizzleByteSize
- //          + (vectorWidth * ((row / perPhase) % maxPhase)) * elementSize
- //
- // where `elementSize` is the size of each element in bytes (2 bytes for 16-bit
- // elements).
- //
- // For example, consider the element at index `(row=1, col=0)` in chunk0:
- //
- // Without swizzling:
- //
- //   offset = row * swizzleByteSize + col * elementSize
- //          = 1 * 128 bytes + 0 * 2 bytes
- //          = 128 bytes
- //
- // With swizzling (assuming `perPhase=1`, `maxPhase=8`, `vectorWidth=8`):
- //
- //   offset = row * swizzleByteSize
- //          + (vectorWidth * ((row / perPhase) % maxPhase)) * elementSize
- //          = 1 * 128 bytes + (8 * ((1 / 1) % 8)) * 2 bytes
- //          = 128 bytes + (8 * (1 % 8)) * 2 bytes
- //          = 128 bytes + 8 * 2 bytes
- //          = 128 bytes + 16 bytes
- //          = 144 bytes
- //
- // This swizzling ensures that elements are stored in a way that optimizes for
- // memory bandwidth and reduces bank conflicts.
- //
- // - Verification through Linear Layout
- //
- // We can verify the offsets with the following outputs of the corresponding
- // linear layout, where each element is 16 bits (2 bytes):
- //
- //   - register=1  -> offset=1
- //     register=2  -> offset=2
- //     register=4  -> offset=4
- //     register=8  -> offset=16
- //     register=16 -> offset=32
- //     register=32 -> offset=8192
- //   - lane=1  -> offset=72
- //     lane=2  -> offset=144
- //     lane=4  -> offset=288
- //     lane=8  -> offset=512
- //     lane=16 -> offset=8
- //   - warp=1 -> offset=1024
- //     warp=2 -> offset=2048
- //     warp=4 -> offset=4096
- //
- // For index `(row=1, col=0)`, which corresponds to `reg=0` and `lane=1` in
- // `warp=0`, the offset is calculated as 72 * 2 bytes = 144 bytes. The result
- // matches our earlier calculation.
- //
- // TODO(Keren): We should replace tensorTy with a LinearLayout and the element
- // bit width of the tensor in the future to support more flexible tensor
- // encodings
- LinearLayout chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
-                                   int swizzleByteSize);
-
- // The primary goal of this function is to efficiently load 2D tiles of a
- // tensor from shared memory using the `ldmatrix` instruction.
- LinearLayout chooseLdMatrixLayout(Attribute enc, ArrayRef<int64_t> shape,
-                                   bool needTrans, int32_t elemBitWidth);
-
// The primary goal of this function is to efficiently load 2D tiles of a
// tensor from shared memory using the `ds_read_tr` instruction for AMD GPUs.
LinearLayout chooseDsReadB64TrLayout(Attribute enc, ArrayRef<int64_t> shape,
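For reference, the following is a minimal, standalone sketch (not part of the Triton API; all names are illustrative) that reproduces the chunk and swizzled-offset arithmetic described in the removed comment, assuming `swizzleByteSize=128`, 16-bit elements, `perPhase=1`, `maxPhase=8`, and `vectorWidth=8`, and checking the worked example `(row=1, col=0) -> 144 bytes`:

#include <cassert>
#include <cstdio>

int main() {
  // Parameters taken from the worked example in the removed comment.
  const int swizzleByteSize = 128;              // bytes per row of a chunk
  const int elementSize = 2;                    // 16-bit elements -> 2 bytes
  const int bytesPerTileRow = 16 * elementSize; // a 16x16 tile row = 32 bytes
  const int tilesPerChunk = swizzleByteSize / bytesPerTileRow;

  const int perPhase = 1;    // rows per swizzle phase
  const int maxPhase = 8;    // total number of phases
  const int vectorWidth = 8; // stmatrix stores 8 contiguous elements per thread

  // Swizzled offset of the start of a row's vector, as given in the comment:
  //   offset = row * swizzleByteSize
  //          + (vectorWidth * ((row / perPhase) % maxPhase)) * elementSize
  auto swizzledRowOffset = [&](int row) {
    return row * swizzleByteSize +
           (vectorWidth * ((row / perPhase) % maxPhase)) * elementSize;
  };

  assert(tilesPerChunk == 4);          // 128 bytes / 32 bytes = 4 tiles
  assert(swizzledRowOffset(1) == 144); // matches the (row=1, col=0) example
  std::printf("tilesPerChunk=%d offset(row=1)=%d bytes\n", tilesPerChunk,
              swizzledRowOffset(1));
  return 0;
}

Under these assumptions the formula only shifts where each row's 8-element vector starts; the 8 elements a thread stores remain contiguous within that vector.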