Skip to content

Commit 17dac54

Browse files
Merge commit '6c3e9535c44774dfd56357acba9c2183b247f58e'
2 parents d344dd3 + 6c3e953 commit 17dac54

File tree

63 files changed

+1607
-474
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

63 files changed

+1607
-474
lines changed

.github/workflows/integration-tests.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ env:
2525
TRITON_BUILD_WITH_CLANG_LLD: "TRUE"
2626
TRITON_USE_ASSERT_ENABLED_LLVM: "TRUE"
2727
TRITON_DISABLE_LINE_INFO: 1
28+
PROTON_SKIP_PC_SAMPLING_TEST: 1
2829
jobs:
2930
Runner-Preparation:
3031
runs-on: ubuntu-latest
@@ -460,7 +461,7 @@ jobs:
460461
- name: Install brew dependencies
461462
run: |
462463
brew update
463-
brew install ccache llvm
464+
brew install ccache llvm@19 lld
464465
- name: Compute cache keys
465466
id: cache-key
466467
run: |

.github/workflows/integration-tests.yml.in

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ env:
2727
TRITON_BUILD_WITH_CLANG_LLD: "TRUE"
2828
TRITON_USE_ASSERT_ENABLED_LLVM: "TRUE"
2929
TRITON_DISABLE_LINE_INFO: 1
30-
30+
PROTON_SKIP_PC_SAMPLING_TEST: 1
3131

3232
jobs:
3333
Runner-Preparation:
@@ -439,7 +439,7 @@ jobs:
439439
- name: Install brew dependencies
440440
run: |
441441
brew update
442-
brew install ccache llvm
442+
brew install ccache llvm@19 lld
443443

444444
- *compute-cache-keys-step
445445
- *cache-build-dependencies-step

bin/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,11 @@ export_executable_symbols_for_plugins(triton-llvm-opt)
102102
add_llvm_executable(triton-tensor-layout triton-tensor-layout.cpp PARTIAL_SOURCES_INTENDED)
103103
target_link_libraries(triton-tensor-layout PRIVATE
104104
TritonGPUIR
105+
TritonNvidiaGPUIR
105106
${triton_libs}
107+
${conversion_libs}
108+
${dialect_libs}
109+
TritonTestAnalysis
106110
)
107111

108112
add_llvm_executable(triton-translate

bin/triton-tensor-layout.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1+
#include "RegisterTritonDialects.h"
2+
13
#include "mlir/AsmParser/AsmParser.h"
24
#include "mlir/AsmParser/AsmParserState.h"
35
#include "mlir/IR/MLIRContext.h"
46

57
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
8+
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
69

710
#include "llvm/Support/CommandLine.h"
811
#include "llvm/Support/ErrorOr.h"
@@ -114,7 +117,7 @@ LogicalResult printLayoutFromFile(MLIRContext *context, StringRef filename,
114117
return failure();
115118
}
116119

117-
auto printLambda = [&](StringRef name, Attribute attr) {
120+
auto printLambda = [&](StringRef name, mlir::Attribute attr) {
118121
ss << "Print layout attribute: #" << name << " = " << attr << "\n";
119122

120123
auto rankedTensorTy = RankedTensorType::get(
@@ -155,7 +158,7 @@ LogicalResult printLayoutFromString(MLIRContext *context,
155158
if (layoutAttrStr.empty())
156159
return success();
157160

158-
Attribute layout = parseAttribute(layoutAttrStr, context);
161+
mlir::Attribute layout = parseAttribute(layoutAttrStr, context);
159162
if (!layout) {
160163
llvm::errs() << "Invalid layout attribute: " << layoutAttrStr << "\n";
161164
return failure();
@@ -178,8 +181,7 @@ int main(int argc, char **argv) {
178181
cl::ParseCommandLineOptions(argc, argv, "tensor layout printer\n");
179182

180183
DialectRegistry registry;
181-
// Register all dialects that can print tensor layout.
182-
registry.insert<triton::gpu::TritonGPUDialect>();
184+
registerTritonDialects(registry);
183185

184186
MLIRContext ctx(registry);
185187
ctx.loadAllAvailableDialects();
@@ -189,7 +191,7 @@ int main(int argc, char **argv) {
189191
return 1;
190192
}
191193

192-
Type parsedTy = parseType(TensorStr, &ctx);
194+
mlir::Type parsedTy = parseType(TensorStr, &ctx);
193195
if (!parsedTy) {
194196
llvm::errs() << "Fail to parse the tensor type argument: " << TensorStr
195197
<< "\n";

include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -57,15 +57,6 @@ class TargetInfoBase {
5757
unsigned numLaneToReduce,
5858
unsigned interleave) const = 0;
5959

60-
// TODO (Keren): Remove this function once layout conversion using stmatrix is
61-
// handled by Linear Layout.
62-
virtual bool processReplicaUsingStMatrix(
63-
RewriterBase &rewriter, Location loc, Value smemBase,
64-
SmallVector<Value> &vals, RankedTensorType srcTy, Type elemTy,
65-
ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> origRepShape,
66-
ArrayRef<unsigned> outOrd, unsigned accumNumReplicates,
67-
int swizzleByteWidth = 0) const = 0;
68-
6960
virtual std::string getMulhiFuncName(Type resultElementTy) const = 0;
7061
// Emits LLVM code with |rewriter| to print a message following the given
7162
// format from the device. |formatStrStart| is the pointer to the start of

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,32 @@ getThreadsPerWarpWithUniqueData(Attribute layout,
7575
SmallVector<unsigned>
7676
getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef<int64_t> tensorShape);
7777

78+
// Returns the dimensions of the tensor from minor (fast-varying) to
79+
// major (slow-varying). For blocked, mma, and dotOperand layouts,
80+
// though the elements are in registers, the order refers to memory
81+
// layout of the original tensor in global memory.
82+
// For shared Layout, the order refers to which dimension of the original tensor
83+
// is contiguous in shared memory.
84+
SmallVector<unsigned> getOrder(Attribute layout);
85+
86+
// Returns the dimensions along which warpId's are distributed.
87+
// warpsPerCTA only tells the warp layout in the CTA, e.g. warpsPerCTA = [2, 4]
88+
// tells there are 2 warps along dim0 and 4 warps along dim1.
89+
// warpOrder tells the specific order when distributing warp IDs.
90+
// E.g. warpOrder = [0, 1] means the warp IDs are distributed as follows
91+
// [warp0 warp2 warp4 warp6]
92+
// [warp1 warp3 warp5 warp7]
93+
// Note that in most cases, getWarpOrder and getOrder return the same results.
94+
// But this is not guaranteed.
7895
SmallVector<unsigned> getWarpOrder(Attribute layout);
7996

80-
SmallVector<unsigned> getOrder(Attribute layout);
97+
// Returns the dimensions along which threadId's are distributed.
98+
// Similar to warpOrder, threadOrder is necessary to tell the specific thread
99+
// distribution in the warp.
100+
// Note that, in most cases, getThreadOrder and getOrder return the same
101+
// results. But this is not guaranteed. One exception is mfma.transposed layout,
102+
// in which getOrder returns [1, 0] but getThreadOrder returns [0, 1].
103+
SmallVector<unsigned> getThreadOrder(Attribute layout);
81104

82105
CTALayoutAttr getCTALayout(Attribute layout);
83106

include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 125 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,12 +113,134 @@ LinearLayout chooseShemLayoutForRegToRegConversion(
113113
// row0 reg[0-1] reg[4-5]
114114
// row8 reg[2-3] reg[6-7]
115115
//
116+
// When `swizzleByteSize` is non-zero, the layout is constructed
117+
// differently due to leading dimension offset and swizzling.
118+
// There are two key concepts to understand:
119+
//
120+
// 1. Chunks: The leading dimension (i.e., the column dimension) is divided
121+
// into chunks, where each chunk's size is determined by `swizzleByteSize`.
122+
// 2. Swizzling within tiles: Each tile applies a swizzling pattern to its
123+
// rows to optimize memory access.
124+
//
125+
// - Concept 1: Chunks
126+
//
127+
// In the swizzled layout, the leading dimension is strided by
128+
// `swizzleByteSize`. This introduces the concept of a "chunk", where each chunk
129+
// spans a certain number of columns.
130+
//
131+
// For a tile size of `stmatrix.x4` (16x16 elements), with each element being 16
132+
// bits (2 bytes), each tile occupies 16 rows and 32 bytes per row (since 16
133+
// elements * 2 bytes per element = 32 bytes per row).
134+
//
135+
// Given a `swizzleByteSize` of 128 bytes, the number of tiles per chunk can be
136+
// calculated as:
137+
//
138+
// Number of tiles per chunk = swizzleByteSize / (bytes per row) = 128 bytes /
139+
// 32 bytes = 4 tiles
140+
//
141+
// Therefore, each chunk contains 4 tiles horizontally, spanning 64 columns
142+
// (since each tile is 16 columns):
143+
//
144+
// col0-15 col16-31 col32-47 col48-63
145+
// row0-15 tile0 tile1 tile2 tile3
146+
//
147+
// For a tensor of size 128x128 elements (#rows x #columns), and each element
148+
// being 16 bits, the tensor can be divided into multiple chunks both
149+
// horizontally and vertically. Chunks are stored in memory in a "column-major"
150+
// order based on chunks, meaning chunk1's address follows chunk0's.
151+
//
152+
// Assuming we have 8 warps, and we assign each warp to process a chunk of 16
153+
// rows (rows per tile) and 128 columns (the width of two chunks). This results
154+
// in each warp handling one horizontal slice of the tensor.
155+
//
156+
// The overall layout can be visualized as:
157+
//
158+
// |<- 128 * 128 bytes ->|<- 128 * 128 bytes ->|
159+
// columns 0-63 columns 64-127
160+
// warp0 | rows 0-15 chunk0 chunk8
161+
// warp1 | rows 16-31 chunk1 chunk9
162+
// warp2 | rows 32-47 chunk2 chunk10
163+
// warp3 | rows 48-63 chunk3 chunk11
164+
// warp4 | rows 64-79 chunk4 chunk12
165+
// warp5 | rows 80-95 chunk5 chunk13
166+
// warp6 | rows 96-111 chunk6 chunk14
167+
// warp7 | rows 112-127 chunk7 chunk15
168+
//
169+
// - Concept 2: Swizzling within tiles
170+
//
171+
// Within each 16x16 tile, rows are swizzled to optimize memory access patterns.
172+
// This swizzling is similar to what's defined in `TritonGPUAttrDefs.td`. at the
173+
// level of each 16x16 tile rather than the entire tensor.
174+
//
175+
// Key parameters for swizzling:
176+
//
177+
// - `perPhase`: The number of rows over which to apply a XOR operation at
178+
// each phase.
179+
// - `maxPhase`: The total number of phases.
180+
// - `vectorWidth`: The number of elements per vector, which is 8 in this case
181+
// because `stmatrix` stores 8 contiguous elements per thread.
182+
//
183+
// The offset of each element within a tile is calculated using the formula:
184+
//
185+
// offset = row * swizzleByteSize + (vectorWidth * ((row / perPhase) %
186+
// maxPhase)) * elementSize
187+
//
188+
// where `elementSize` is the size of each element in bytes (2 bytes for 16-bit
189+
// elements).
190+
//
191+
// For example, consider the element at index `(row=1, col=0)` in chunk0:
192+
//
193+
// Without swizzling:
194+
//
195+
// offset = row * swizzleByteSize + col * elementSize
196+
// = 1 * 128 bytes + 0 * 2 bytes
197+
// = 128 bytes
198+
//
199+
// With swizzling (assuming `perPhase=1`, `maxPhase=8`, `vectorWidth=8`):
200+
//
201+
// offset = row * swizzleByteSize + (vectorWidth * ((row / perPhase) %
202+
// maxPhase)) * elementSize
203+
// = 1 * 128 bytes + (8 * ((1 / 1) % 8)) * 2 bytes
204+
// = 128 bytes + (8 * (1 % 8)) * 2 bytes
205+
// = 128 bytes + 8 * 2 bytes
206+
// = 128 bytes + 16 bytes
207+
// = 144 bytes
208+
//
209+
// This swizzling ensures that elements are stored in a way that optimizes for
210+
// memory bandwidth and reduces bank conflicts.
211+
//
212+
// - Verification through Linear Layout
213+
//
214+
// We can verify the offsets with the following outputs of the corresponding
215+
// linear layout, where each element is 16 bits (2 bytes):
216+
//
217+
// - register=1 -> offset=1
218+
// register=2 -> offset=2
219+
// register=4 -> offset=4
220+
// register=8 -> offset=16
221+
// register=16 -> offset=32
222+
// register=32 -> offset=8192
223+
// - lane=1 -> offset=72
224+
// lane=2 -> offset=144
225+
// lane=4 -> offset=288
226+
// lane=8 -> offset=512
227+
// lane=16 -> offset=8
228+
// - warp=1 -> offset=1024
229+
// warp=2 -> offset=2048
230+
// warp=4 -> offset=4096
231+
//
232+
// For index `(row=1, col=0)`, which corresponds to `reg=0` and `lane=1` in
233+
// `warp=0`, the offset is calculated as 72 * 2 bytes = 144 bytes. The result
234+
// matches our earlier calculation.
235+
//
116236
// TODO(Keren): We should replace tensorTy with a LinearLayout and the element
117237
// bit width of the tensor in the future to support more flexible tensor
118238
// encodings
119-
std::optional<LinearLayout> chooseStMatrixLayoutForRegToRegConversion(
120-
MLIRContext *ctx, RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
121-
ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> order);
239+
std::optional<LinearLayout>
240+
chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
241+
ArrayRef<unsigned> repShape,
242+
ArrayRef<unsigned> paddedRepShape,
243+
ArrayRef<unsigned> order, int swizzleByteSize);
122244
} // namespace mlir::triton::gpu
123245

124246
#endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H

lib/Analysis/Utility.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ SmallVector<unsigned> getParentOrder(Attribute layout) {
3838
if (auto sliceEncoding = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
3939
return getParentOrder(sliceEncoding.getParent());
4040
}
41-
return getOrder(layout);
41+
return getThreadOrder(layout);
4242
}
4343

4444
} // namespace
@@ -77,7 +77,7 @@ unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() {
7777
threadOffset = threadsPerWarp[sliceLayout.getDim()];
7878
} else {
7979
auto threadsPerWarp = getThreadsPerWarp(srcLayout);
80-
auto order = getOrder(srcLayout);
80+
auto order = getThreadOrder(srcLayout);
8181
for (unsigned i = 0; i < order.size(); i++) {
8282
if (order[i] == axis)
8383
break;

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -215,15 +215,9 @@ struct ConvertLayoutOpConversion
215215
if (repId != 0) {
216216
barrier();
217217
}
218-
auto successful = targetInfo.processReplicaUsingStMatrix(
219-
rewriter, loc, smemBase, vals, srcTy,
220-
getTypeConverter()->convertType(srcTy.getElementType()),
221-
paddedRepShape, origRepShape, outOrd, accumNumReplicates);
222-
if (!successful) {
223-
processReplica(loc, rewriter, /*stNotRd*/ true, srcTy, inNumCTAsEachRep,
224-
multiDimRepId, inVec, paddedRepShape, origRepShape,
225-
outOrd, vals, smemBase);
226-
}
218+
processReplica(loc, rewriter, /*stNotRd*/ true, srcTy, inNumCTAsEachRep,
219+
multiDimRepId, inVec, paddedRepShape, origRepShape, outOrd,
220+
vals, smemBase);
227221
barrier();
228222
processReplica(loc, rewriter, /*stNotRd*/ false, dstTy, outNumCTAsEachRep,
229223
multiDimRepId, outVec, paddedRepShape, origRepShape,
@@ -483,9 +477,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
483477
// Input dims: [reg, lane, warp]
484478
// Output dims: [offset, iteration]
485479
std::optional<LinearLayout> shmemStoreLayout =
486-
chooseStMatrixLayoutForRegToRegConversion(
487-
ctx, op.getSrc().getType(), scratchConfig.repShape,
488-
scratchConfig.paddedRepShape, scratchConfig.order);
480+
chooseStMatrixLayout(ctx, op.getSrc().getType(), scratchConfig.repShape,
481+
scratchConfig.paddedRepShape, scratchConfig.order,
482+
/*swizzleByteSize=*/0);
489483
bool isStMatrix = shmemStoreLayout.has_value();
490484
if (!isStMatrix) {
491485
shmemStoreLayout = srcLayout.invertAndCompose(sharedLayout);

lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,6 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
116116
RankedTensorType dstTy = op.getType();
117117
Attribute srcLayout = srcTy.getEncoding();
118118
Attribute dstLayout = dstTy.getEncoding();
119-
// TODO: do we need to check if src is shared ?
120119
if (isa<SharedEncodingAttr>(srcLayout) &&
121120
isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr>(
122121
dstLayout)) {

0 commit comments

Comments
 (0)