Commit 49266aa

[BACKEND] Linear Layout with stmatrix part 2: support stmatrix for local_alloc ops (#4763)
This PR enables the use of `stmatrix` for `local_alloc` ops through linear layout and removes the legacy code from the `TargetInfo` class.
1 parent 80a5cfb commit 49266aa

10 files changed: +287 -215 lines changed

include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h

Lines changed: 0 additions & 9 deletions
@@ -57,15 +57,6 @@ class TargetInfoBase {
                           unsigned numLaneToReduce,
                           unsigned interleave) const = 0;
 
-  // TODO (Keren): Remove this function once layout conversion using stmatrix is
-  // handled by Linear Layout.
-  virtual bool processReplicaUsingStMatrix(
-      RewriterBase &rewriter, Location loc, Value smemBase,
-      SmallVector<Value> &vals, RankedTensorType srcTy, Type elemTy,
-      ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> origRepShape,
-      ArrayRef<unsigned> outOrd, unsigned accumNumReplicates,
-      int swizzleByteWidth = 0) const = 0;
-
   virtual std::string getMulhiFuncName(Type resultElementTy) const = 0;
   // Emits LLVM code with |rewriter| to print a message following the given
   // format from the device. |formatStrStart| is the pointer to the start of

include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 125 additions & 3 deletions
@@ -113,12 +113,134 @@ LinearLayout chooseShemLayoutForRegToRegConversion(
 // row0  reg[0-1]  reg[4-5]
 // row8  reg[2-3]  reg[6-7]
 //
+// When `swizzleByteSize` is non-zero, the layout is constructed differently
+// due to the leading-dimension offset and swizzling. There are two key
+// concepts to understand:
+//
+// 1. Chunks: The leading dimension (i.e., the column dimension) is divided
+//    into chunks, where each chunk's size is determined by `swizzleByteSize`.
+// 2. Swizzling within tiles: Each tile applies a swizzling pattern to its
+//    rows to optimize memory access.
+//
+// - Concept 1: Chunks
+//
+// In the swizzled layout, the leading dimension is strided by
+// `swizzleByteSize`. This introduces the concept of a "chunk", where each
+// chunk spans a certain number of columns.
+//
+// For a tile size of `stmatrix.x4` (16x16 elements), with each element being
+// 16 bits (2 bytes), each tile occupies 16 rows and 32 bytes per row (since
+// 16 elements * 2 bytes per element = 32 bytes per row).
+//
+// Given a `swizzleByteSize` of 128 bytes, the number of tiles per chunk is:
+//
+//   tiles per chunk = swizzleByteSize / bytes per row
+//                   = 128 bytes / 32 bytes = 4 tiles
+//
+// Therefore, each chunk contains 4 tiles horizontally, spanning 64 columns
+// (since each tile is 16 columns):
+//
+//            col0-15  col16-31  col32-47  col48-63
+//   row0-15   tile0    tile1     tile2     tile3
+//
+// For a tensor of size 128x128 elements (#rows x #columns), with each element
+// being 16 bits, the tensor can be divided into multiple chunks both
+// horizontally and vertically. Chunks are stored in memory in column-major
+// order, meaning chunk1's address follows chunk0's.
+//
+// Assume we have 8 warps and assign each warp a slice of 16 rows (the rows of
+// one tile) and 128 columns (the width of two chunks), so that each warp
+// handles one horizontal slice of the tensor.
+//
+// The overall layout can be visualized as:
+//
+//                       |<- 128 * 128 bytes ->|<- 128 * 128 bytes ->|
+//                           columns 0-63          columns 64-127
+//   warp0 | rows 0-15        chunk0                  chunk8
+//   warp1 | rows 16-31       chunk1                  chunk9
+//   warp2 | rows 32-47       chunk2                  chunk10
+//   warp3 | rows 48-63       chunk3                  chunk11
+//   warp4 | rows 64-79       chunk4                  chunk12
+//   warp5 | rows 80-95       chunk5                  chunk13
+//   warp6 | rows 96-111      chunk6                  chunk14
+//   warp7 | rows 112-127     chunk7                  chunk15
+//
+// - Concept 2: Swizzling within tiles
+//
+// Within each 16x16 tile, rows are swizzled to optimize memory access
+// patterns. This swizzling is similar to what's defined in
+// `TritonGPUAttrDefs.td`, but applied at the level of each 16x16 tile rather
+// than the entire tensor.
+//
+// Key parameters for swizzling:
+//
+//   - `perPhase`: The number of consecutive rows that share the same
+//     swizzling phase (i.e., the same XOR pattern).
+//   - `maxPhase`: The total number of phases.
+//   - `vectorWidth`: The number of elements per vector, which is 8 in this
+//     case because `stmatrix` stores 8 contiguous elements per thread.
+//
+// The offset of each element within a tile is calculated using the formula:
+//
+//   offset = row * swizzleByteSize
+//          + (vectorWidth * ((row / perPhase) % maxPhase)) * elementSize
+//
+// where `elementSize` is the size of each element in bytes (2 bytes for
+// 16-bit elements).
+//
+// For example, consider the element at index `(row=1, col=0)` in chunk0:
+//
+// Without swizzling:
+//
+//   offset = row * swizzleByteSize + col * elementSize
+//          = 1 * 128 bytes + 0 * 2 bytes
+//          = 128 bytes
+//
+// With swizzling (assuming `perPhase=1`, `maxPhase=8`, `vectorWidth=8`):
+//
+//   offset = row * swizzleByteSize
+//          + (vectorWidth * ((row / perPhase) % maxPhase)) * elementSize
+//          = 1 * 128 bytes + (8 * ((1 / 1) % 8)) * 2 bytes
+//          = 128 bytes + (8 * 1) * 2 bytes
+//          = 128 bytes + 16 bytes
+//          = 144 bytes
+//
+// This swizzling ensures that elements are stored in a way that optimizes
+// memory bandwidth and reduces bank conflicts.
+//
+// - Verification through Linear Layout
+//
+// We can verify the offsets with the following outputs of the corresponding
+// linear layout, where each element is 16 bits (2 bytes):
+//
+//   - register=1  -> offset=1
+//     register=2  -> offset=2
+//     register=4  -> offset=4
+//     register=8  -> offset=16
+//     register=16 -> offset=32
+//     register=32 -> offset=8192
+//   - lane=1      -> offset=72
+//     lane=2      -> offset=144
+//     lane=4      -> offset=288
+//     lane=8      -> offset=512
+//     lane=16     -> offset=8
+//   - warp=1      -> offset=1024
+//     warp=2      -> offset=2048
+//     warp=4      -> offset=4096
+//
+// For index `(row=1, col=0)`, which corresponds to `reg=0` and `lane=1` in
+// `warp=0`, the offset is 72 elements * 2 bytes = 144 bytes, which matches
+// our earlier calculation.
+//
 // TODO(Keren): We should replace tensorTy with a LinearLayout and the element
 // bit width of the tensor in the future to support more flexible tensor
 // encodings
-std::optional<LinearLayout> chooseStMatrixLayoutForRegToRegConversion(
-    MLIRContext *ctx, RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
-    ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> order);
+std::optional<LinearLayout>
+chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
+                     ArrayRef<unsigned> repShape,
+                     ArrayRef<unsigned> paddedRepShape,
+                     ArrayRef<unsigned> order, int swizzleByteSize);
 } // namespace mlir::triton::gpu
 
 #endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H
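To make the worked example in the header comment concrete, here is a minimal standalone sketch of the offset arithmetic (plain C++, illustration only; `swizzledOffsetBytes` is a hypothetical helper, not part of the Triton API, following the formula for the `col=0` case with `perPhase=1`, `maxPhase=8`, `vectorWidth=8`, and 2-byte elements):

#include <cassert>
#include <cstdio>

// Hypothetical helper (illustration only): byte offset of the element at
// (row, col=0) within a chunk, per the formula in the comment above.
int swizzledOffsetBytes(int row, int swizzleByteSize, int perPhase,
                        int maxPhase, int vectorWidth, int elementSize) {
  return row * swizzleByteSize +
         (vectorWidth * ((row / perPhase) % maxPhase)) * elementSize;
}

int main() {
  // (row=1, col=0) in chunk0: 1 * 128 + (8 * ((1 / 1) % 8)) * 2 = 144 bytes,
  // matching both the hand calculation and lane=1 -> offset=72 elements.
  assert(swizzledOffsetBytes(1, 128, 1, 8, 8, 2) == 144);
  std::printf("%d\n", swizzledOffsetBytes(1, 128, 1, 8, 8, 2));
  return 0;
}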

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 6 additions & 12 deletions
@@ -215,15 +215,9 @@ struct ConvertLayoutOpConversion
       if (repId != 0) {
         barrier();
       }
-      auto successful = targetInfo.processReplicaUsingStMatrix(
-          rewriter, loc, smemBase, vals, srcTy,
-          getTypeConverter()->convertType(srcTy.getElementType()),
-          paddedRepShape, origRepShape, outOrd, accumNumReplicates);
-      if (!successful) {
-        processReplica(loc, rewriter, /*stNotRd*/ true, srcTy, inNumCTAsEachRep,
-                       multiDimRepId, inVec, paddedRepShape, origRepShape,
-                       outOrd, vals, smemBase);
-      }
+      processReplica(loc, rewriter, /*stNotRd*/ true, srcTy, inNumCTAsEachRep,
+                     multiDimRepId, inVec, paddedRepShape, origRepShape, outOrd,
+                     vals, smemBase);
       barrier();
       processReplica(loc, rewriter, /*stNotRd*/ false, dstTy, outNumCTAsEachRep,
                      multiDimRepId, outVec, paddedRepShape, origRepShape,
@@ -483,9 +477,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     // Input dims: [reg, lane, warp]
     // Output dims: [offset, iteration]
     std::optional<LinearLayout> shmemStoreLayout =
-        chooseStMatrixLayoutForRegToRegConversion(
-            ctx, op.getSrc().getType(), scratchConfig.repShape,
-            scratchConfig.paddedRepShape, scratchConfig.order);
+        chooseStMatrixLayout(ctx, op.getSrc().getType(), scratchConfig.repShape,
+                             scratchConfig.paddedRepShape, scratchConfig.order,
+                             /*swizzleByteSize=*/0);
     bool isStMatrix = shmemStoreLayout.has_value();
     if (!isStMatrix) {
       shmemStoreLayout = srcLayout.invertAndCompose(sharedLayout);

lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 0 additions & 1 deletion
@@ -116,7 +116,6 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
     RankedTensorType dstTy = op.getType();
     Attribute srcLayout = srcTy.getEncoding();
     Attribute dstLayout = dstTy.getEncoding();
-    // TODO: do we need to check if src is shared ?
     if (isa<SharedEncodingAttr>(srcLayout) &&
         isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr>(
             dstLayout)) {

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 96 additions & 7 deletions
@@ -820,8 +820,8 @@ namespace {
 // stmatrix. These restrictions are retained from legacy code, and we could
 // relax some of them in the future.
 bool canUseStMatrix(RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
-                    ArrayRef<unsigned> paddedRepShape,
-                    ArrayRef<unsigned> order) {
+                    ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> order,
+                    int swizzleByteSize) {
   auto mmaLayout =
       mlir::dyn_cast<NvidiaMmaEncodingAttr>(tensorTy.getEncoding());
   if (!mmaLayout || !mmaLayout.isHopper())
@@ -840,17 +840,87 @@ bool canUseStMatrix(RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
     return false;
   if (paddedRepShape[1] % 8 != 0)
     return false;
+  if (swizzleByteSize != 0 && swizzleByteSize != 32 && swizzleByteSize != 64 &&
+      swizzleByteSize != 128)
+    return false;
   return true;
 }
 
-} // anonymous namespace
+std::optional<LinearLayout> chooseStMatrixLayoutLeadingOffset(
+    MLIRContext *ctx, RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
+    ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> order,
+    int swizzleByteSize) {
+  StringAttr kReg = S("register");
+  StringAttr kLane = S("lane");
+  StringAttr kWarp = S("warp");
+  StringAttr kCol = S("dim1");
+  StringAttr kRow = S("dim0");
+  StringAttr kOffset = S("offset");
+
+  int perPhase;
+  int maxPhase;
+  if (swizzleByteSize == 32) {
+    perPhase = 4;
+    maxPhase = 2;
+  } else if (swizzleByteSize == 64) {
+    perPhase = 2;
+    maxPhase = 4;
+  } else if (swizzleByteSize == 128) {
+    perPhase = 1;
+    maxPhase = 8;
+  } else {
+    llvm::errs() << "Illegal swizzleByteSize: " << swizzleByteSize << "\n";
+    llvm::report_fatal_error("Illegal swizzleByteSize");
+  }
+
+  // stmatrix only supports 16-bit elements, and each vector has 8 elements
+  int elemBitWidth = 16;
+  int vecSize = 8;
+  int numRows = 16;
+  int numCols = 8 * swizzleByteSize / elemBitWidth;
+
+  // Construct a single stmatrix.x4 (16x16) tile
+  std::vector<std::vector<int>> basesReg = {{1, 0}, {2, 0}, {4, 0}};
+  std::vector<std::vector<int>> basesLane;
+  for (int logRow = 0; logRow < llvm::Log2_32(numRows); logRow++) {
+    int row = 1 << logRow;
+    basesLane.push_back({vecSize * ((row / perPhase) % maxPhase), row});
+  }
+  basesLane.push_back({8, 0});
+
+  // Expand the tile's register dimension to fit swizzleByteSize, which is a
+  // "chunk"
+  for (int logChunk = 0; logChunk < llvm::Log2_32(numCols / 16); logChunk++) {
+    int chunk = 1 << logChunk;
+    basesReg.push_back({16 * chunk, 0});
+  }
+
+  // Construct the layout for a single chunk
+  LinearLayout layout =
+      LinearLayout({{kReg, basesReg}, {kLane, basesLane}}, {kCol, kRow});
 
-std::optional<LinearLayout> chooseStMatrixLayoutForRegToRegConversion(
+  // Expand the `warp` dimension according to warpsPerCTA.
+  auto mma = cast<NvidiaMmaEncodingAttr>(tensorTy.getEncoding());
+  layout *=
+      identityND(kWarp, mma.getWarpsPerCTA(), /*order=*/{0, 1}, {kRow, kCol})
+          .transposeOuts(llvm::to_vector(layout.getOutDimNames()));
+
+  // Expand the `register` dimension so the size of columns matches `n`.
+  int n = mma.getInstrShape()[1];
+  int numWarpRows = layout.getOutDimSize(kRow);
+  layout = (layout.reshapeOuts({{kOffset, layout.getTotalOutDimSize()}}) *
+            LinearLayout::identity1D(n / numCols, kReg, kOffset))
+               .reshapeOuts({{kCol, n}, {kRow, numWarpRows}});
+
+  auto ret =
+      combineCtaCgaWithShape(layout, mma.getCTALayout(), tensorTy.getShape());
+  return ret.transposeOuts(llvm::to_vector(layout.getOutDimNames()))
+      .reshapeOuts({{kOffset, ret.getTotalOutDimSize()}, {S("iteration"), 1}});
+}
+
+std::optional<LinearLayout> chooseStMatrixLayoutNoLeadingOffset(
     MLIRContext *ctx, RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
     ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> order) {
-  if (!canUseStMatrix(tensorTy, repShape, paddedRepShape, order))
-    return std::nullopt;
-
   StringAttr kReg = S("register");
   StringAttr kLane = S("lane");
   StringAttr kWarp = S("warp");
@@ -880,4 +950,23 @@ std::optional<LinearLayout> chooseStMatrixLayoutForRegToRegConversion(
       {{S("offset"), ret.getTotalOutDimSize()}, {S("iteration"), 1}});
 }
 
+} // anonymous namespace
+
+std::optional<LinearLayout>
+chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
+                     ArrayRef<unsigned> repShape,
+                     ArrayRef<unsigned> paddedRepShape,
+                     ArrayRef<unsigned> order, int swizzleByteSize) {
+  if (!canUseStMatrix(tensorTy, repShape, paddedRepShape, order,
+                      swizzleByteSize))
+    return std::nullopt;
+
+  if (swizzleByteSize == 0)
+    return chooseStMatrixLayoutNoLeadingOffset(ctx, tensorTy, repShape,
+                                               paddedRepShape, order);
+  else
+    return chooseStMatrixLayoutLeadingOffset(
+        ctx, tensorTy, repShape, paddedRepShape, order, swizzleByteSize);
+}
+
 } // namespace mlir::triton::gpu
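As a sanity check on the "Verification through Linear Layout" table in the header comment, the following standalone sketch (plain C++17; it deliberately does not use Triton's LinearLayout class) mirrors the `basesLane` construction in `chooseStMatrixLayoutLeadingOffset` for `swizzleByteSize=128` and flattens each `(col, row)` basis into an element offset with the column dimension fastest-varying:

#include <cstdio>
#include <utility>
#include <vector>

int main() {
  const int perPhase = 1, maxPhase = 8; // the 128-byte swizzle parameters
  const int vecSize = 8;  // stmatrix stores 8 contiguous elements per thread
  const int numRows = 16; // one stmatrix.x4 tile
  const int numCols = 8 * 128 / 16; // 64 elements per chunk row

  // Mirrors the basesLane loop above: one (col, row) basis per power-of-two
  // row, plus the {8, 0} basis for the second 16x8 half of the tile.
  std::vector<std::pair<int, int>> basesLane;
  for (int row = 1; row < numRows; row *= 2)
    basesLane.push_back({vecSize * ((row / perPhase) % maxPhase), row});
  basesLane.push_back({8, 0});

  int lane = 1;
  for (auto [col, row] : basesLane) {
    std::printf("lane=%-2d -> offset=%d\n", lane, row * numCols + col);
    lane *= 2;
  }
  return 0;
}

This prints 72, 144, 288, 512, and 8 for lanes 1, 2, 4, 8, and 16, matching the table; the corresponding byte offsets follow by multiplying by the 2-byte element size.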

third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp

Lines changed: 0 additions & 9 deletions
@@ -136,15 +136,6 @@ bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc,
   return false;
 }
 
-bool TargetInfo::processReplicaUsingStMatrix(
-    RewriterBase &rewriter, Location loc, Value smemBase,
-    SmallVector<Value> &vals, RankedTensorType srcTy, Type elemTy,
-    ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> origRepShape,
-    ArrayRef<unsigned> outOrd, unsigned accumNumReplicates,
-    int swizzleByteWidth) const {
-  return false;
-}
-
 void TargetInfo::printfImpl(Value formatStrStart, int formatStrByteCount,
                             ValueRange args, RewriterBase &rewriter,
                             bool useStdErr) const {

third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h

Lines changed: 0 additions & 9 deletions
@@ -46,15 +46,6 @@ class TargetInfo : public mlir::triton::TargetInfoBase {
                   triton::ReduceOp op, unsigned numLaneToReduce,
                   unsigned interleave) const override;
 
-  bool processReplicaUsingStMatrix(RewriterBase &rewriter, Location loc,
-                                   Value smemBase, SmallVector<Value> &vals,
-                                   RankedTensorType srcTy, Type elemTy,
-                                   ArrayRef<unsigned> paddedRepShape,
-                                   ArrayRef<unsigned> origRepShape,
-                                   ArrayRef<unsigned> outOrd,
-                                   unsigned accumNumReplicates,
-                                   int swizzleByteWidth) const override;
-
   std::string getMulhiFuncName(Type resultElementTy) const override;
 
   void printf(RewriterBase &rewriter, Value formatStrStart,
