@@ -113,12 +113,134 @@ LinearLayout chooseShemLayoutForRegToRegConversion(
113113// row0 reg[0-1] reg[4-5]
114114// row8 reg[2-3] reg[6-7]
115115//
116+ // When `swizzleByteSize` is non-zero, the layout is constructed
117+ // differently due to leading dimension offset and swizzling.
118+ // There are two key concepts to understand:
119+ //
120+ // 1. Chunks: The leading dimension (i.e., the column dimension) is divided
121+ // into chunks, where each chunk's size is determined by `swizzleByteSize`.
122+ // 2. Swizzling within tiles: Each tile applies a swizzling pattern to its
123+ // rows to optimize memory access.
124+ //
125+ // - Concept 1: Chunks
126+ //
127+ // In the swizzled layout, the leading dimension is strided by
128+ // `swizzleByteSize`. This introduces the concept of a "chunk", where each chunk
129+ // spans `swizzleByteSize` bytes per row, i.e., a fixed number of columns.
130+ //
131+ // For a tile size of `stmatrix.x4` (16x16 elements), with each element being 16
132+ // bits (2 bytes), each tile occupies 16 rows and 32 bytes per row (since 16
133+ // elements * 2 bytes per element = 32 bytes per row).
134+ //
135+ // Given a `swizzleByteSize` of 128 bytes, the number of tiles per chunk can be
136+ // calculated as:
137+ //
138+ // Number of tiles per chunk = swizzleByteSize / (bytes per row) = 128 bytes /
139+ // 32 bytes = 4 tiles
140+ //
141+ // Therefore, each chunk contains 4 tiles horizontally, spanning 64 columns
142+ // (since each tile is 16 columns):
143+ //
144+ // col0-15 col16-31 col32-47 col48-63
145+ // row0-15 tile0 tile1 tile2 tile3
146+ //
147+ // For a tensor of size 128x128 elements (#rows x #columns), and each element
148+ // being 16 bits, the tensor can be divided into multiple chunks both
149+ // horizontally and vertically. Chunks are stored in memory in "column-major"
150+ // order, meaning chunk1's address immediately follows chunk0's.
151+ //
152+ // Assume we have 8 warps, and we assign each warp a slice of 16 rows (the
153+ // rows of one tile) and 128 columns (the width of two chunks). As a result,
154+ // each warp handles one horizontal slice of the tensor.
155+ //
156+ // The overall layout can be visualized as:
157+ //
158+ // |<- 128 * 128 bytes ->|<- 128 * 128 bytes ->|
159+ // columns 0-63 columns 64-127
160+ // warp0 | rows 0-15 chunk0 chunk8
161+ // warp1 | rows 16-31 chunk1 chunk9
162+ // warp2 | rows 32-47 chunk2 chunk10
163+ // warp3 | rows 48-63 chunk3 chunk11
164+ // warp4 | rows 64-79 chunk4 chunk12
165+ // warp5 | rows 80-95 chunk5 chunk13
166+ // warp6 | rows 96-111 chunk6 chunk14
167+ // warp7 | rows 112-127 chunk7 chunk15
168+ //
169+ // - Concept 2: Swizzling within tiles
170+ //
171+ // Within each 16x16 tile, rows are swizzled to optimize memory access patterns.
172+ // This swizzling is similar to what's defined in `TritonGPUAttrDefs.td`. at the
173+ // level of each 16x16 tile rather than the entire tensor.
174+ //
175+ // Key parameters for swizzling:
176+ //
177+ // - `perPhase`: The number of rows over which to apply an XOR operation at
178+ // each phase.
179+ // - `maxPhase`: The total number of phases.
180+ // - `vectorWidth`: The number of elements per vector, which is 8 in this case
181+ // because `stmatrix` stores 8 contiguous elements per thread.
182+ //
183+ // The offset of each element within a tile is calculated using the formula:
184+ //
185+ // offset = row * swizzleByteSize + (vectorWidth * ((row / perPhase) %
186+ // maxPhase)) * elementSize
187+ //
188+ // where `elementSize` is the size of each element in bytes (2 bytes for 16-bit
189+ // elements).
190+ //
191+ // For example, consider the element at index `(row=1, col=0)` in chunk0:
192+ //
193+ // Without swizzling:
194+ //
195+ // offset = row * swizzleByteSize + col * elementSize
196+ // = 1 * 128 bytes + 0 * 2 bytes
197+ // = 128 bytes
198+ //
199+ // With swizzling (assuming `perPhase=1`, `maxPhase=8`, `vectorWidth=8`):
200+ //
201+ // offset = row * swizzleByteSize + (vectorWidth * ((row / perPhase) %
202+ // maxPhase)) * elementSize
203+ // = 1 * 128 bytes + (8 * ((1 / 1) % 8)) * 2 bytes
204+ // = 128 bytes + (8 * (1 % 8)) * 2 bytes
205+ // = 128 bytes + 8 * 2 bytes
206+ // = 128 bytes + 16 bytes
207+ // = 144 bytes
208+ //
209+ // This swizzling ensures that elements are stored in a way that optimizes for
210+ // memory bandwidth and reduces bank conflicts.
211+ //
212+ // - Verification through Linear Layout
213+ //
214+ // We can verify the offsets with the following outputs of the corresponding
215+ // linear layout, where each element is 16 bits (2 bytes):
216+ //
217+ // - register=1 -> offset=1
218+ // register=2 -> offset=2
219+ // register=4 -> offset=4
220+ // register=8 -> offset=16
221+ // register=16 -> offset=32
222+ // register=32 -> offset=8192
223+ // - lane=1 -> offset=72
224+ // lane=2 -> offset=144
225+ // lane=4 -> offset=288
226+ // lane=8 -> offset=512
227+ // lane=16 -> offset=8
228+ // - warp=1 -> offset=1024
229+ // warp=2 -> offset=2048
230+ // warp=4 -> offset=4096
231+ //
232+ // For index `(row=1, col=0)`, which corresponds to `reg=0` and `lane=1` in
233+ // `warp=0`, the offset is calculated as 72 * 2 bytes = 144 bytes. The result
234+ // matches our earlier calculation.
235+ //
116236// TODO(Keren): We should replace tensorTy with a LinearLayout and the element
117237// bit width of the tensor in the future to support more flexible tensor
118238// encodings
119- std::optional<LinearLayout> chooseStMatrixLayoutForRegToRegConversion (
120- MLIRContext *ctx, RankedTensorType tensorTy, ArrayRef<unsigned > repShape,
121- ArrayRef<unsigned > paddedRepShape, ArrayRef<unsigned > order);
239+ std::optional<LinearLayout>
240+ chooseStMatrixLayout (MLIRContext *ctx, RankedTensorType tensorTy,
241+ ArrayRef<unsigned > repShape,
242+ ArrayRef<unsigned > paddedRepShape,
243+ ArrayRef<unsigned > order, int swizzleByteSize);
122244} // namespace mlir::triton::gpu
123245
124246#endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H
0 commit comments