Skip to content

Commit 06941f4

Browse files
authored
[LAYOUTS] Implement generically getUniqueContigPerThread (#5840)
This allows vectorisation on global loads and smem in some cases where we didn't use it before, as we now compute the order of the elements by looking at the actual LinearLayout associated with the given shape of the tensor, which is quite neat. We end up touching a few things in the Scan lowering, as BlockedLayouts, when converted to LinearEncodings, may not have the same order on elems/threads/warps. This is a feature, not a bug, as it takes us closer to supporting arbitrary LinearEncodings within the tt.scan op.
1 parent e7072a3 commit 06941f4

File tree

8 files changed

+152
-187
lines changed

8 files changed

+152
-187
lines changed

include/triton/Analysis/Utility.h

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -89,14 +89,26 @@ class ScanLoweringHelper {
8989
explicit ScanLoweringHelper(triton::ScanOp op) : scanOp(op) {
9090
auto firstTy = cast<RankedTensorType>(op.getOperands()[0].getType());
9191
srcShape = firstTy.getShape();
92-
srcEncoding = firstTy.getEncoding();
92+
legacyEncoding = firstTy.getEncoding();
93+
srcEncoding = triton::gpu::toLinearEncoding(legacyEncoding, srcShape);
9394
srcElementTypes = op.getElementTypes();
95+
// The codegen does not support different element/thread/warp order so
96+
// we choose one a priori. We choose that of the blocked encoding.
97+
// When we generalise this code to other layouts we'll probably need to
98+
// get rid of all this logic and the *Stride auxiliary methods
99+
// and replace them by transposes and reshapes on the LinearLayout
100+
if (auto blockedEncoding =
101+
dyn_cast<triton::gpu::BlockedEncodingAttr>(legacyEncoding)) {
102+
order = llvm::to_vector(blockedEncoding.getOrder());
103+
} else {
104+
order = srcEncoding.getOrder();
105+
}
94106

95107
for (const auto &t : op.getInputTypes()) {
96108
if (t.getShape() != srcShape) {
97109
op.emitError() << "shape mismatch";
98110
}
99-
if (t.getEncoding() != srcEncoding) {
111+
if (t.getEncoding() != legacyEncoding) {
100112
op.emitError() << "encoding mismatch";
101113
}
102114
}
@@ -111,12 +123,8 @@ class ScanLoweringHelper {
111123
unsigned getNonAxisNumThreadsPerWarp();
112124
// Return the flat numbers of threads computing independent scan results.
113125
unsigned getNonAxisNumThreadsPerCTA();
114-
// Return the number of warps per CTA along axis dim.
115-
unsigned getAxisNumWarps();
116126
// Return the number of warps per CTA along axis dim with unique data.
117127
unsigned getAxisNumWarpsWithUniqueData();
118-
// Return the number of threads per warp along axis dim.
119-
unsigned getAxisNumThreadsPerWarp();
120128
// Return the number of threads per warp along axis dim with unique data.
121129
unsigned getAxisNumThreadsPerWarpWithUniqueData();
122130
// Return the number of blocks along axis dim.
@@ -139,18 +147,20 @@ class ScanLoweringHelper {
139147
Location getLoc() { return scanOp.getLoc(); }
140148
unsigned getAxis() { return scanOp.getAxis(); }
141149
bool getReverse() { return scanOp.getReverse(); }
142-
triton::gpu::BlockedEncodingAttr getEncoding();
150+
triton::gpu::LinearEncodingAttr getEncoding() { return srcEncoding; }
143151
llvm::ArrayRef<int64_t> getShape() { return srcShape; }
144152
unsigned getNumOperands() { return scanOp.getNumOperands(); }
145153
SmallVector<Type> getElementTypes() { return srcElementTypes; }
146-
Attribute getSrcLayout() { return srcEncoding; }
154+
SmallVector<unsigned> getOrder() { return order; }
147155
Region &getCombineOp();
148156

149157
private:
150158
triton::ScanOp scanOp;
151-
Attribute srcEncoding;
159+
triton::gpu::LinearEncodingAttr srcEncoding;
160+
Attribute legacyEncoding;
152161
llvm::ArrayRef<int64_t> srcShape;
153162
SmallVector<Type> srcElementTypes;
163+
SmallVector<unsigned> order;
154164
};
155165

156166
// Helper class for lowering `tt.gather` operations. This class shares lowering

lib/Analysis/Allocation.cpp

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ namespace triton {
2828

2929
// Bitwidth of pointers
3030
constexpr int kPtrBitWidth = 64;
31+
// Max shmem LDS/STS instruction in bits
32+
constexpr int kMaxShmemVecBitLength = 128;
3133

3234
static SmallVector<unsigned> getRepShapeForCvt(RankedTensorType srcTy,
3335
RankedTensorType dstTy) {
@@ -79,15 +81,17 @@ std::pair<unsigned, unsigned>
7981
getScratchCvtInOutVecLengths(RankedTensorType srcTy, RankedTensorType dstTy) {
8082
Attribute srcLayout = srcTy.getEncoding();
8183
Attribute dstLayout = dstTy.getEncoding();
82-
const auto &inOrd = gpu::getOrder(srcLayout);
83-
const auto &outOrd = gpu::getOrder(dstLayout);
84+
85+
auto srcLinAttr = gpu::toLinearEncoding(srcLayout, srcTy.getShape());
86+
auto dstLinAttr = gpu::toLinearEncoding(dstLayout, dstTy.getShape());
87+
auto inOrd = srcLinAttr.getOrder();
88+
auto outOrd = dstLinAttr.getOrder();
89+
8490
unsigned rank = srcTy.getRank();
8591

86-
unsigned srcContigPerThread =
87-
gpu::getUniqueContigPerThread(srcLayout, srcTy.getShape())[inOrd[0]];
88-
unsigned dstContigPerThread =
89-
gpu::getUniqueContigPerThread(dstLayout, dstTy.getShape())[outOrd[0]];
90-
// TODO: Fix the legacy issue that ourOrd[0] == 0 always means
92+
unsigned srcContigPerThread = srcLinAttr.getContigPerThread()[inOrd[0]];
93+
unsigned dstContigPerThread = dstLinAttr.getContigPerThread()[outOrd[0]];
94+
// TODO: Fix the legacy issue that outOrd[0] == 0 always means
9195
// that we cannot do vectorization.
9296
unsigned innerDim = rank - 1;
9397
unsigned inVec = outOrd[0] != innerDim ? 1
@@ -117,8 +121,7 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
117121
Attribute dstLayout = dstTy.getEncoding();
118122

119123
assert(cvtNeedsSharedMemory(srcTy, dstTy));
120-
121-
const auto &outOrd = gpu::getOrder(dstLayout);
124+
auto outOrd = gpu::toLinearEncoding(dstLayout, dstTy.getShape()).getOrder();
122125
scratchConfig.order = outOrd;
123126

124127
std::tie(scratchConfig.inVec, scratchConfig.outVec) =
@@ -129,6 +132,18 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
129132
unsigned contiguousShapeDim = scratchConfig.repShape[scratchConfig.order[0]];
130133
scratchConfig.inVec = std::min(scratchConfig.inVec, contiguousShapeDim);
131134
scratchConfig.outVec = std::min(scratchConfig.outVec, contiguousShapeDim);
135+
// Clamp the vector length to kMaxShmemVecBitLength / element bitwidth as this
136+
// is the max vectorisation
137+
auto inBitWidth = isa<PointerType>(srcTy.getElementType())
138+
? kPtrBitWidth
139+
: srcTy.getElementTypeBitWidth();
140+
auto outBitWidth = isa<PointerType>(dstTy.getElementType())
141+
? kPtrBitWidth
142+
: dstTy.getElementTypeBitWidth();
143+
scratchConfig.inVec =
144+
std::min(scratchConfig.inVec, kMaxShmemVecBitLength / inBitWidth);
145+
scratchConfig.outVec =
146+
std::min(scratchConfig.outVec, kMaxShmemVecBitLength / outBitWidth);
132147

133148
// No padding is required if the tensor is 1-D, or if all dimensions except
134149
// the first accessed dimension have a size of 1.

lib/Analysis/AxisInfo.cpp

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1222,15 +1222,16 @@ unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) {
12221222
auto tensorTy = dyn_cast<RankedTensorType>(ptr.getType());
12231223
if (!tensorTy)
12241224
return 1;
1225-
auto layout = tensorTy.getEncoding();
12261225

1227-
// Here order should be ordered by contiguous first, so the first element
1228-
// should have the largest contiguous.
1229-
auto order = triton::gpu::getOrder(layout);
1226+
// FIXME: This is not as good as it could be, as we don't need to restrict
1227+
// the analysis to one dimension. We should determine contiguity on the
1228+
// flattenOuts() layout
1229+
auto linAttr =
1230+
gpu::toLinearEncoding(tensorTy.getEncoding(), tensorTy.getShape());
1231+
auto order = linAttr.getOrder();
12301232
unsigned align = getPtrAlignment(ptr);
12311233

1232-
auto uniqueContigPerThread =
1233-
triton::gpu::getUniqueContigPerThread(layout, tensorTy.getShape());
1234+
auto uniqueContigPerThread = linAttr.getContigPerThread();
12341235
assert(order[0] < uniqueContigPerThread.size() &&
12351236
"Unexpected uniqueContigPerThread size");
12361237
unsigned contiguity = uniqueContigPerThread[order[0]];
@@ -1247,8 +1248,9 @@ unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) {
12471248
auto *axisInfo = getAxisInfo(ptr);
12481249
if (!axisInfo)
12491250
return 1;
1250-
auto layout = tensorTy.getEncoding();
1251-
auto order = triton::gpu::getOrder(layout);
1251+
auto linAttr =
1252+
gpu::toLinearEncoding(tensorTy.getEncoding(), tensorTy.getShape());
1253+
auto order = linAttr.getOrder();
12521254
auto maxMultipleBytes = axisInfo->getDivisibility(order[0]);
12531255
auto maxContig = axisInfo->getContiguity(order[0]);
12541256
auto elemNumBits = triton::getPointeeBitWidth(tensorTy);
@@ -1275,7 +1277,9 @@ unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) {
12751277
auto *axisInfo = getAxisInfo(mask);
12761278
if (!axisInfo)
12771279
return 1;
1278-
auto maskOrder = triton::gpu::getOrder(tensorTy.getEncoding());
1280+
auto linAttr =
1281+
gpu::toLinearEncoding(tensorTy.getEncoding(), tensorTy.getShape());
1282+
auto maskOrder = linAttr.getOrder();
12791283
auto alignment = std::max<unsigned>(axisInfo->getConstancy(maskOrder[0]), 1);
12801284
LDBG("getMaskAlignment maskOrder[0] " << maskOrder[0] << " alignment "
12811285
<< alignment);

lib/Analysis/Utility.cpp

Lines changed: 33 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -22,37 +22,18 @@
2222
#include "triton/Tools/Sys/GetEnv.hpp"
2323

2424
namespace mlir {
25-
namespace {
2625

2726
using namespace triton;
2827
using namespace triton::gpu;
2928

30-
int getParentAxis(Attribute layout, int axis) {
31-
if (auto sliceEncoding = dyn_cast<SliceEncodingAttr>(layout)) {
32-
axis = axis < sliceEncoding.getDim() ? axis : axis + 1;
33-
return getParentAxis(sliceEncoding.getParent(), axis);
34-
}
35-
return axis;
36-
}
37-
38-
SmallVector<unsigned> getParentOrder(Attribute layout) {
39-
if (auto sliceEncoding = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
40-
return getParentOrder(sliceEncoding.getParent());
41-
}
42-
return getThreadOrder(layout);
43-
}
44-
45-
} // namespace
46-
4729
// TODO(jlebar): Move this class into namespace triton.
4830
bool ReduceOpHelper::isReductionOnLayoutFastAxis() {
49-
return getParentAxis(getSrcLayout(), axis) ==
50-
getParentOrder(getSrcLayout())[0];
31+
auto linearEncoding = toLinearEncoding(getSrcLayout(), getSrcShape());
32+
return linearEncoding.getOrder()[0] == axis;
5133
}
5234

5335
SmallVector<unsigned> ReduceOpHelper::getOrderWithAxisAtBeginning() {
54-
auto srcLayout = getSrcLayout();
55-
auto order = getOrder(srcLayout);
36+
auto order = toLinearEncoding(getSrcLayout(), getSrcShape()).getOrder();
5637
auto it = std::find(order.begin(), order.end(), axis);
5738
// delete the axis from order
5839
order.erase(it);
@@ -219,69 +200,59 @@ bool ReduceOpHelper::isSupportedLayout() {
219200
}
220201

221202
unsigned ScanLoweringHelper::getAxisNumElementsPerThread() {
222-
return getEncoding().getSizePerThread()[getAxis()];
203+
return getEncoding().getContigPerThread()[getAxis()];
223204
}
224205

225206
unsigned ScanLoweringHelper::getNonAxisNumElementsPerThread() {
226-
SmallVector<unsigned> sizePerThreads = getContigPerThread(getEncoding());
227-
sizePerThreads[getAxis()] = 1;
228-
return product<unsigned>(sizePerThreads);
207+
auto contigPerThread = getEncoding().getContigPerThread();
208+
contigPerThread[getAxis()] = 1;
209+
return product<unsigned>(contigPerThread);
229210
}
230211

231212
Region &ScanLoweringHelper::getCombineOp() { return scanOp.getCombineOp(); }
232213

233-
unsigned ScanLoweringHelper::getAxisNumThreadsPerWarp() {
234-
return getThreadsPerWarp(getEncoding())[getAxis()];
235-
}
236-
237214
unsigned ScanLoweringHelper::getAxisNumThreadsPerWarpWithUniqueData() {
238-
return getThreadsPerWarpWithUniqueData(getEncoding(), getShape())[getAxis()];
215+
return getEncoding().getThreadsPerWarp()[getAxis()];
239216
}
240217

241218
unsigned ScanLoweringHelper::getNonAxisNumThreadsPerWarp() {
242-
auto threadsPerWarp = getThreadsPerWarp(getEncoding());
243-
threadsPerWarp[getAxis()] = 1;
244-
return product<unsigned>(threadsPerWarp);
219+
auto nThreads = product(getEncoding().getThreadsPerWarp());
220+
return nThreads / getAxisNumThreadsPerWarpWithUniqueData();
245221
}
246222

247223
// Return the flat numbers of threads computing independent scan results.
248224
unsigned ScanLoweringHelper::getNonAxisNumThreadsPerCTA() {
249-
unsigned numParallelThreadsPerWarp = getNonAxisNumThreadsPerWarp();
250-
auto warpsPerCTA = getWarpsPerCTA(getEncoding());
251-
warpsPerCTA[getAxis()] = 1;
252-
unsigned numParallelWarpsPerCTA = product<unsigned>(warpsPerCTA);
253-
return numParallelThreadsPerWarp * numParallelWarpsPerCTA;
254-
}
255-
256-
unsigned ScanLoweringHelper::getAxisNumWarps() {
257-
return getWarpsPerCTA(getEncoding())[getAxis()];
225+
auto nWarps = product(getEncoding().getWarpsPerCTA());
226+
return (nWarps / getAxisNumWarpsWithUniqueData()) *
227+
getNonAxisNumThreadsPerWarp();
258228
}
259229

260230
unsigned ScanLoweringHelper::getAxisNumWarpsWithUniqueData() {
261-
return getWarpsPerCTAWithUniqueData(getEncoding(), getShape())[getAxis()];
231+
return getEncoding().getWarpsPerCTA()[getAxis()];
262232
}
263233

264234
unsigned ScanLoweringHelper::getAxisNumBlocks() {
265-
auto sizePerThreads = getSizePerThread(getEncoding());
235+
auto contigPerThread = getEncoding().getContigPerThread();
266236
auto threadsPerWarp = getThreadsPerWarp(getEncoding());
267237
auto warpsPerCTA = getWarpsPerCTA(getEncoding());
268238
unsigned axis = getAxis();
269239
return ceil<unsigned>(
270240
getShape()[axis],
271-
(sizePerThreads[axis] * threadsPerWarp[axis] * warpsPerCTA[axis]));
241+
(contigPerThread[axis] * threadsPerWarp[axis] * warpsPerCTA[axis]));
272242
}
273243

274244
unsigned ScanLoweringHelper::getNonAxisNumBlocks() {
275-
auto sizePerThreads = getSizePerThread(getEncoding());
245+
auto contigPerThread = getEncoding().getContigPerThread();
276246
auto threadsPerWarp = getThreadsPerWarp(getEncoding());
277247
auto warpsPerCTA = getWarpsPerCTA(getEncoding());
248+
auto rank = contigPerThread.size();
278249
unsigned axis = getAxis();
279250
unsigned numBlocks = 1;
280-
for (unsigned i = 0; i < sizePerThreads.size(); i++) {
251+
for (unsigned i = 0; i < rank; i++) {
281252
if (i == axis)
282253
continue;
283254
numBlocks *=
284-
ceil<unsigned>(getShape()[i], (sizePerThreads[i] * threadsPerWarp[i] *
255+
ceil<unsigned>(getShape()[i], (contigPerThread[i] * threadsPerWarp[i] *
285256
warpsPerCTA[i]));
286257
}
287258
return numBlocks;
@@ -290,7 +261,7 @@ unsigned ScanLoweringHelper::getNonAxisNumBlocks() {
290261
bool ScanLoweringHelper::isSupported() {
291262
// TODO: Support the following cases:
292263
// 1. Scan on non-blocking encodings
293-
if (!isa<BlockedEncodingAttr>(srcEncoding))
264+
if (!isa<BlockedEncodingAttr>(legacyEncoding))
294265
return false;
295266
return true;
296267
}
@@ -578,42 +549,43 @@ getReshapeDecomposition(ArrayRef<int64_t> srcShape,
578549
return ret;
579550
}
580551

581-
BlockedEncodingAttr ScanLoweringHelper::getEncoding() {
582-
return cast<BlockedEncodingAttr>(srcEncoding);
583-
}
584-
585552
unsigned ScanLoweringHelper::getAxisElementStride() {
586-
auto order = getOrder(getEncoding());
553+
auto order = getOrder();
587554
unsigned stride = 1;
588555
for (unsigned dim : order) {
589556
if (dim == getAxis())
590557
return stride;
591-
stride *= getContigPerThread(getEncoding())[dim];
558+
stride *= getEncoding().getContigPerThread()[dim];
592559
}
593560
llvm_unreachable("Axis not found in order");
594561
}
595562

596563
unsigned ScanLoweringHelper::getAxisThreadStride() {
597-
auto order = getOrder(getEncoding());
564+
auto encoding = getEncoding();
565+
auto kThread = StringAttr::get(encoding.getContext(), "lane");
566+
// OOOGHHH This is nasty. We should implement this lowering via LLs natively
567+
// to avoid this
568+
auto threadsPerWarp = encoding.basesPerDim(kThread, /*skipBroadcast=*/false);
569+
auto order = getOrder();
598570
unsigned stride = 1;
599571
for (unsigned dim : order) {
600572
if (dim == getAxis())
601573
return stride;
602-
stride *= getEncoding().getThreadsPerWarp()[dim];
574+
stride *= threadsPerWarp[dim];
603575
}
604576
llvm_unreachable("Axis not found in order");
605577
}
606578

607579
unsigned ScanLoweringHelper::getAxisBlockStride() {
608-
auto order = getOrder(getEncoding());
580+
auto order = getOrder();
609581
unsigned stride = 1;
610-
auto sizePerThreads = getSizePerThread(getEncoding());
582+
auto contigPerThread = getEncoding().getContigPerThread();
611583
auto threadsPerWarp = getThreadsPerWarp(getEncoding());
612584
auto warpsPerCTA = getWarpsPerCTA(getEncoding());
613585
for (unsigned dim : order) {
614586
if (dim == getAxis())
615587
return stride;
616-
stride *= ceil<unsigned int>(getShape()[dim], sizePerThreads[dim] *
588+
stride *= ceil<unsigned int>(getShape()[dim], contigPerThread[dim] *
617589
threadsPerWarp[dim] *
618590
warpsPerCTA[dim]);
619591
}

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ using ::mlir::LLVM::linearize;
99
using ::mlir::triton::gpu::DotOperandEncodingAttr;
1010
using ::mlir::triton::gpu::expandMatrixOrderWithBatch;
1111
using ::mlir::triton::gpu::expandMatrixShapeWithBatch;
12-
using ::mlir::triton::gpu::getContigPerThread;
1312
using ::mlir::triton::gpu::getOrder;
1413
using ::mlir::triton::gpu::getShapePerCTA;
1514
using ::mlir::triton::gpu::getSizePerThread;

0 commit comments

Comments
 (0)