Skip to content

Commit dad6800

Browse files
Merge commit '06941f490322679231aae20bfe20b61e9885ad48'
2 parents 6b1642e + 06941f4 commit dad6800

File tree

14 files changed

+332
-217
lines changed

14 files changed

+332
-217
lines changed

include/triton/Analysis/Utility.h

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -89,14 +89,26 @@ class ScanLoweringHelper {
8989
explicit ScanLoweringHelper(triton::ScanOp op) : scanOp(op) {
9090
auto firstTy = cast<RankedTensorType>(op.getOperands()[0].getType());
9191
srcShape = firstTy.getShape();
92-
srcEncoding = firstTy.getEncoding();
92+
legacyEncoding = firstTy.getEncoding();
93+
srcEncoding = triton::gpu::toLinearEncoding(legacyEncoding, srcShape);
9394
srcElementTypes = op.getElementTypes();
95+
// The codegen does not support different element/thread/warp order so
96+
// we choose one a priori. We choose that of the blocked encoding.
97+
// When we generalise this code to other layouts we'll probably need to
98+
// get rid of all this logic and the *Stride auxiliary methods
99+
// and replace them by transposes and reshapes on the LinearLayout
100+
if (auto blockedEncoding =
101+
dyn_cast<triton::gpu::BlockedEncodingAttr>(legacyEncoding)) {
102+
order = llvm::to_vector(blockedEncoding.getOrder());
103+
} else {
104+
order = srcEncoding.getOrder();
105+
}
94106

95107
for (const auto &t : op.getInputTypes()) {
96108
if (t.getShape() != srcShape) {
97109
op.emitError() << "shape mismatch";
98110
}
99-
if (t.getEncoding() != srcEncoding) {
111+
if (t.getEncoding() != legacyEncoding) {
100112
op.emitError() << "encoding mismatch";
101113
}
102114
}
@@ -111,12 +123,8 @@ class ScanLoweringHelper {
111123
unsigned getNonAxisNumThreadsPerWarp();
112124
// Return the flat numbers of threads computing independent scan results.
113125
unsigned getNonAxisNumThreadsPerCTA();
114-
// Return the number of warps per CTA along axis dim.
115-
unsigned getAxisNumWarps();
116126
// Return the number of warps per CTA along axis dim with unique data.
117127
unsigned getAxisNumWarpsWithUniqueData();
118-
// Return the number of threads per warp along axis dim.
119-
unsigned getAxisNumThreadsPerWarp();
120128
// Return the number of threads per warp along axis dim with unique data.
121129
unsigned getAxisNumThreadsPerWarpWithUniqueData();
122130
// Return the number of blocks along axis dim.
@@ -139,18 +147,20 @@ class ScanLoweringHelper {
139147
Location getLoc() { return scanOp.getLoc(); }
140148
unsigned getAxis() { return scanOp.getAxis(); }
141149
bool getReverse() { return scanOp.getReverse(); }
142-
triton::gpu::BlockedEncodingAttr getEncoding();
150+
triton::gpu::LinearEncodingAttr getEncoding() { return srcEncoding; }
143151
llvm::ArrayRef<int64_t> getShape() { return srcShape; }
144152
unsigned getNumOperands() { return scanOp.getNumOperands(); }
145153
SmallVector<Type> getElementTypes() { return srcElementTypes; }
146-
Attribute getSrcLayout() { return srcEncoding; }
154+
SmallVector<unsigned> getOrder() { return order; }
147155
Region &getCombineOp();
148156

149157
private:
150158
triton::ScanOp scanOp;
151-
Attribute srcEncoding;
159+
triton::gpu::LinearEncodingAttr srcEncoding;
160+
Attribute legacyEncoding;
152161
llvm::ArrayRef<int64_t> srcShape;
153162
SmallVector<Type> srcElementTypes;
163+
SmallVector<unsigned> order;
154164
};
155165

156166
// Helper class for lowering `tt.gather` operations. This class shares lowering

include/triton/Dialect/TritonGPU/Transforms/Schedule.h

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -106,27 +106,10 @@ class CoarseSchedule {
106106
return true;
107107
}
108108

109-
void insertMinimum(Operation *op, int stage, Cluster cluster) {
110-
auto res = opToStageAndCluster.insert({op, {stage, cluster}});
111-
if (res.second) {
112-
return;
113-
}
114-
auto &[existingStage, existingCluster] = res.first->second;
115-
existingStage = std::min(stage, existingStage);
116-
117-
// If existingCluster is reachable from cluster,
118-
// then cluster is earlier in the list
119-
auto it = cluster;
120-
for (auto it = cluster; it != clusters.end(); ++it) {
121-
if (it == existingCluster) {
122-
existingCluster = cluster;
123-
return;
124-
}
125-
}
126-
}
109+
bool insertMinimum(Operation *op, int stage, Cluster cluster);
127110

128111
bool insertDepsOfOp(Operation *op, int stage, CoarseSchedule::Cluster cluster,
129-
bool includeArg);
112+
bool includeArg, bool insertIfEarlier = false);
130113

131114
void erase(Operation *op) { opToStageAndCluster.erase(op); }
132115

lib/Analysis/Allocation.cpp

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ namespace triton {
2828

2929
// Bitwidth of pointers
3030
constexpr int kPtrBitWidth = 64;
31+
// Max shmem LDS/STS instruction in bits
32+
constexpr int kMaxShmemVecBitLength = 128;
3133

3234
static SmallVector<unsigned> getRepShapeForCvt(RankedTensorType srcTy,
3335
RankedTensorType dstTy) {
@@ -79,15 +81,17 @@ std::pair<unsigned, unsigned>
7981
getScratchCvtInOutVecLengths(RankedTensorType srcTy, RankedTensorType dstTy) {
8082
Attribute srcLayout = srcTy.getEncoding();
8183
Attribute dstLayout = dstTy.getEncoding();
82-
const auto &inOrd = gpu::getOrder(srcLayout);
83-
const auto &outOrd = gpu::getOrder(dstLayout);
84+
85+
auto srcLinAttr = gpu::toLinearEncoding(srcLayout, srcTy.getShape());
86+
auto dstLinAttr = gpu::toLinearEncoding(dstLayout, dstTy.getShape());
87+
auto inOrd = srcLinAttr.getOrder();
88+
auto outOrd = dstLinAttr.getOrder();
89+
8490
unsigned rank = srcTy.getRank();
8591

86-
unsigned srcContigPerThread =
87-
gpu::getUniqueContigPerThread(srcLayout, srcTy.getShape())[inOrd[0]];
88-
unsigned dstContigPerThread =
89-
gpu::getUniqueContigPerThread(dstLayout, dstTy.getShape())[outOrd[0]];
90-
// TODO: Fix the legacy issue that ourOrd[0] == 0 always means
92+
unsigned srcContigPerThread = srcLinAttr.getContigPerThread()[inOrd[0]];
93+
unsigned dstContigPerThread = dstLinAttr.getContigPerThread()[outOrd[0]];
94+
// TODO: Fix the legacy issue that outOrd[0] == 0 always means
9195
// that we cannot do vectorization.
9296
unsigned innerDim = rank - 1;
9397
unsigned inVec = outOrd[0] != innerDim ? 1
@@ -117,8 +121,7 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
117121
Attribute dstLayout = dstTy.getEncoding();
118122

119123
assert(cvtNeedsSharedMemory(srcTy, dstTy));
120-
121-
const auto &outOrd = gpu::getOrder(dstLayout);
124+
auto outOrd = gpu::toLinearEncoding(dstLayout, dstTy.getShape()).getOrder();
122125
scratchConfig.order = outOrd;
123126

124127
std::tie(scratchConfig.inVec, scratchConfig.outVec) =
@@ -129,6 +132,18 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
129132
unsigned contiguousShapeDim = scratchConfig.repShape[scratchConfig.order[0]];
130133
scratchConfig.inVec = std::min(scratchConfig.inVec, contiguousShapeDim);
131134
scratchConfig.outVec = std::min(scratchConfig.outVec, contiguousShapeDim);
135+
// Clamp the vector length to kMaxShmemVecBitLength / element bitwidth as this
136+
// is the max vectorisation
137+
auto inBitWidth = isa<PointerType>(srcTy.getElementType())
138+
? kPtrBitWidth
139+
: srcTy.getElementTypeBitWidth();
140+
auto outBitWidth = isa<PointerType>(dstTy.getElementType())
141+
? kPtrBitWidth
142+
: dstTy.getElementTypeBitWidth();
143+
scratchConfig.inVec =
144+
std::min(scratchConfig.inVec, kMaxShmemVecBitLength / inBitWidth);
145+
scratchConfig.outVec =
146+
std::min(scratchConfig.outVec, kMaxShmemVecBitLength / outBitWidth);
132147

133148
// No padding is required if the tensor is 1-D, or if all dimensions except
134149
// the first accessed dimension have a size of 1.

lib/Analysis/AxisInfo.cpp

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1222,15 +1222,16 @@ unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) {
12221222
auto tensorTy = dyn_cast<RankedTensorType>(ptr.getType());
12231223
if (!tensorTy)
12241224
return 1;
1225-
auto layout = tensorTy.getEncoding();
12261225

1227-
// Here order should be ordered by contiguous first, so the first element
1228-
// should have the largest contiguous.
1229-
auto order = triton::gpu::getOrder(layout);
1226+
// FIXME: This is not as good as it could be, as we don't need to restrict
1227+
// the analysis to one dimension. We should determine contiguity on the
1228+
// flattenOuts() layout
1229+
auto linAttr =
1230+
gpu::toLinearEncoding(tensorTy.getEncoding(), tensorTy.getShape());
1231+
auto order = linAttr.getOrder();
12301232
unsigned align = getPtrAlignment(ptr);
12311233

1232-
auto uniqueContigPerThread =
1233-
triton::gpu::getUniqueContigPerThread(layout, tensorTy.getShape());
1234+
auto uniqueContigPerThread = linAttr.getContigPerThread();
12341235
assert(order[0] < uniqueContigPerThread.size() &&
12351236
"Unexpected uniqueContigPerThread size");
12361237
unsigned contiguity = uniqueContigPerThread[order[0]];
@@ -1247,8 +1248,9 @@ unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) {
12471248
auto *axisInfo = getAxisInfo(ptr);
12481249
if (!axisInfo)
12491250
return 1;
1250-
auto layout = tensorTy.getEncoding();
1251-
auto order = triton::gpu::getOrder(layout);
1251+
auto linAttr =
1252+
gpu::toLinearEncoding(tensorTy.getEncoding(), tensorTy.getShape());
1253+
auto order = linAttr.getOrder();
12521254
auto maxMultipleBytes = axisInfo->getDivisibility(order[0]);
12531255
auto maxContig = axisInfo->getContiguity(order[0]);
12541256
auto elemNumBits = triton::getPointeeBitWidth(tensorTy);
@@ -1275,7 +1277,9 @@ unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) {
12751277
auto *axisInfo = getAxisInfo(mask);
12761278
if (!axisInfo)
12771279
return 1;
1278-
auto maskOrder = triton::gpu::getOrder(tensorTy.getEncoding());
1280+
auto linAttr =
1281+
gpu::toLinearEncoding(tensorTy.getEncoding(), tensorTy.getShape());
1282+
auto maskOrder = linAttr.getOrder();
12791283
auto alignment = std::max<unsigned>(axisInfo->getConstancy(maskOrder[0]), 1);
12801284
LDBG("getMaskAlignment maskOrder[0] " << maskOrder[0] << " alignment "
12811285
<< alignment);

lib/Analysis/Utility.cpp

Lines changed: 33 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -23,37 +23,18 @@
2323
#include "triton/Tools/Sys/GetEnv.hpp"
2424

2525
namespace mlir {
26-
namespace {
2726

2827
using namespace triton;
2928
using namespace triton::gpu;
3029

31-
int getParentAxis(Attribute layout, int axis) {
32-
if (auto sliceEncoding = dyn_cast<SliceEncodingAttr>(layout)) {
33-
axis = axis < sliceEncoding.getDim() ? axis : axis + 1;
34-
return getParentAxis(sliceEncoding.getParent(), axis);
35-
}
36-
return axis;
37-
}
38-
39-
SmallVector<unsigned> getParentOrder(Attribute layout) {
40-
if (auto sliceEncoding = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
41-
return getParentOrder(sliceEncoding.getParent());
42-
}
43-
return getThreadOrder(layout);
44-
}
45-
46-
} // namespace
47-
4830
// TODO(jlebar): Move this class into namespace triton.
4931
bool ReduceOpHelper::isReductionOnLayoutFastAxis() {
50-
return getParentAxis(getSrcLayout(), axis) ==
51-
getParentOrder(getSrcLayout())[0];
32+
auto linearEncoding = toLinearEncoding(getSrcLayout(), getSrcShape());
33+
return linearEncoding.getOrder()[0] == axis;
5234
}
5335

5436
SmallVector<unsigned> ReduceOpHelper::getOrderWithAxisAtBeginning() {
55-
auto srcLayout = getSrcLayout();
56-
auto order = getOrder(srcLayout);
37+
auto order = toLinearEncoding(getSrcLayout(), getSrcShape()).getOrder();
5738
auto it = std::find(order.begin(), order.end(), axis);
5839
// delete the axis from order
5940
order.erase(it);
@@ -225,69 +206,59 @@ bool ReduceOpHelper::isSupportedLayout() {
225206
}
226207

227208
unsigned ScanLoweringHelper::getAxisNumElementsPerThread() {
228-
return getEncoding().getSizePerThread()[getAxis()];
209+
return getEncoding().getContigPerThread()[getAxis()];
229210
}
230211

231212
unsigned ScanLoweringHelper::getNonAxisNumElementsPerThread() {
232-
SmallVector<unsigned> sizePerThreads = getContigPerThread(getEncoding());
233-
sizePerThreads[getAxis()] = 1;
234-
return product<unsigned>(sizePerThreads);
213+
auto contigPerThread = getEncoding().getContigPerThread();
214+
contigPerThread[getAxis()] = 1;
215+
return product<unsigned>(contigPerThread);
235216
}
236217

237218
Region &ScanLoweringHelper::getCombineOp() { return scanOp.getCombineOp(); }
238219

239-
unsigned ScanLoweringHelper::getAxisNumThreadsPerWarp() {
240-
return getThreadsPerWarp(getEncoding())[getAxis()];
241-
}
242-
243220
unsigned ScanLoweringHelper::getAxisNumThreadsPerWarpWithUniqueData() {
244-
return getThreadsPerWarpWithUniqueData(getEncoding(), getShape())[getAxis()];
221+
return getEncoding().getThreadsPerWarp()[getAxis()];
245222
}
246223

247224
unsigned ScanLoweringHelper::getNonAxisNumThreadsPerWarp() {
248-
auto threadsPerWarp = getThreadsPerWarp(getEncoding());
249-
threadsPerWarp[getAxis()] = 1;
250-
return product<unsigned>(threadsPerWarp);
225+
auto nThreads = product(getEncoding().getThreadsPerWarp());
226+
return nThreads / getAxisNumThreadsPerWarpWithUniqueData();
251227
}
252228

253229
// Return the flat numbers of threads computing independent scan results.
254230
unsigned ScanLoweringHelper::getNonAxisNumThreadsPerCTA() {
255-
unsigned numParallelThreadsPerWarp = getNonAxisNumThreadsPerWarp();
256-
auto warpsPerCTA = getWarpsPerCTA(getEncoding());
257-
warpsPerCTA[getAxis()] = 1;
258-
unsigned numParallelWarpsPerCTA = product<unsigned>(warpsPerCTA);
259-
return numParallelThreadsPerWarp * numParallelWarpsPerCTA;
260-
}
261-
262-
unsigned ScanLoweringHelper::getAxisNumWarps() {
263-
return getWarpsPerCTA(getEncoding())[getAxis()];
231+
auto nWarps = product(getEncoding().getWarpsPerCTA());
232+
return (nWarps / getAxisNumWarpsWithUniqueData()) *
233+
getNonAxisNumThreadsPerWarp();
264234
}
265235

266236
unsigned ScanLoweringHelper::getAxisNumWarpsWithUniqueData() {
267-
return getWarpsPerCTAWithUniqueData(getEncoding(), getShape())[getAxis()];
237+
return getEncoding().getWarpsPerCTA()[getAxis()];
268238
}
269239

270240
unsigned ScanLoweringHelper::getAxisNumBlocks() {
271-
auto sizePerThreads = getSizePerThread(getEncoding());
241+
auto contigPerThread = getEncoding().getContigPerThread();
272242
auto threadsPerWarp = getThreadsPerWarp(getEncoding());
273243
auto warpsPerCTA = getWarpsPerCTA(getEncoding());
274244
unsigned axis = getAxis();
275245
return ceil<unsigned>(
276246
getShape()[axis],
277-
(sizePerThreads[axis] * threadsPerWarp[axis] * warpsPerCTA[axis]));
247+
(contigPerThread[axis] * threadsPerWarp[axis] * warpsPerCTA[axis]));
278248
}
279249

280250
unsigned ScanLoweringHelper::getNonAxisNumBlocks() {
281-
auto sizePerThreads = getSizePerThread(getEncoding());
251+
auto contigPerThread = getEncoding().getContigPerThread();
282252
auto threadsPerWarp = getThreadsPerWarp(getEncoding());
283253
auto warpsPerCTA = getWarpsPerCTA(getEncoding());
254+
auto rank = contigPerThread.size();
284255
unsigned axis = getAxis();
285256
unsigned numBlocks = 1;
286-
for (unsigned i = 0; i < sizePerThreads.size(); i++) {
257+
for (unsigned i = 0; i < rank; i++) {
287258
if (i == axis)
288259
continue;
289260
numBlocks *=
290-
ceil<unsigned>(getShape()[i], (sizePerThreads[i] * threadsPerWarp[i] *
261+
ceil<unsigned>(getShape()[i], (contigPerThread[i] * threadsPerWarp[i] *
291262
warpsPerCTA[i]));
292263
}
293264
return numBlocks;
@@ -296,7 +267,7 @@ unsigned ScanLoweringHelper::getNonAxisNumBlocks() {
296267
bool ScanLoweringHelper::isSupported() {
297268
// TODO: Support the following cases:
298269
// 1. Scan on non-blocking encodings
299-
if (!isa<BlockedEncodingAttr>(srcEncoding))
270+
if (!isa<BlockedEncodingAttr>(legacyEncoding))
300271
return false;
301272
return true;
302273
}
@@ -584,42 +555,43 @@ getReshapeDecomposition(ArrayRef<int64_t> srcShape,
584555
return ret;
585556
}
586557

587-
BlockedEncodingAttr ScanLoweringHelper::getEncoding() {
588-
return cast<BlockedEncodingAttr>(srcEncoding);
589-
}
590-
591558
unsigned ScanLoweringHelper::getAxisElementStride() {
592-
auto order = getOrder(getEncoding());
559+
auto order = getOrder();
593560
unsigned stride = 1;
594561
for (unsigned dim : order) {
595562
if (dim == getAxis())
596563
return stride;
597-
stride *= getContigPerThread(getEncoding())[dim];
564+
stride *= getEncoding().getContigPerThread()[dim];
598565
}
599566
llvm_unreachable("Axis not found in order");
600567
}
601568

602569
unsigned ScanLoweringHelper::getAxisThreadStride() {
603-
auto order = getOrder(getEncoding());
570+
auto encoding = getEncoding();
571+
auto kThread = StringAttr::get(encoding.getContext(), "lane");
572+
// OOOGHHH This is nasty. We should implement this lowering via LLs natively
573+
// to avoid this
574+
auto threadsPerWarp = encoding.basesPerDim(kThread, /*skipBroadcast=*/false);
575+
auto order = getOrder();
604576
unsigned stride = 1;
605577
for (unsigned dim : order) {
606578
if (dim == getAxis())
607579
return stride;
608-
stride *= getEncoding().getThreadsPerWarp()[dim];
580+
stride *= threadsPerWarp[dim];
609581
}
610582
llvm_unreachable("Axis not found in order");
611583
}
612584

613585
unsigned ScanLoweringHelper::getAxisBlockStride() {
614-
auto order = getOrder(getEncoding());
586+
auto order = getOrder();
615587
unsigned stride = 1;
616-
auto sizePerThreads = getSizePerThread(getEncoding());
588+
auto contigPerThread = getEncoding().getContigPerThread();
617589
auto threadsPerWarp = getThreadsPerWarp(getEncoding());
618590
auto warpsPerCTA = getWarpsPerCTA(getEncoding());
619591
for (unsigned dim : order) {
620592
if (dim == getAxis())
621593
return stride;
622-
stride *= ceil<unsigned int>(getShape()[dim], sizePerThreads[dim] *
594+
stride *= ceil<unsigned int>(getShape()[dim], contigPerThread[dim] *
623595
threadsPerWarp[dim] *
624596
warpsPerCTA[dim]);
625597
}

0 commit comments

Comments
 (0)