Skip to content

Commit dd36563

Browse files
Revert "[LAYOUTS] Implement generically getUniqueContigPerThread (#5840)"
This reverts commit 06941f4.
1 parent dad6800 commit dd36563

File tree

8 files changed

+187
-152
lines changed

8 files changed

+187
-152
lines changed

include/triton/Analysis/Utility.h

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -89,26 +89,14 @@ class ScanLoweringHelper {
8989
explicit ScanLoweringHelper(triton::ScanOp op) : scanOp(op) {
9090
auto firstTy = cast<RankedTensorType>(op.getOperands()[0].getType());
9191
srcShape = firstTy.getShape();
92-
legacyEncoding = firstTy.getEncoding();
93-
srcEncoding = triton::gpu::toLinearEncoding(legacyEncoding, srcShape);
92+
srcEncoding = firstTy.getEncoding();
9493
srcElementTypes = op.getElementTypes();
95-
// The codegen does not support different element/thread/warp order so
96-
// we choose one a priori. We choose that of the blocked encoding.
97-
// When we generalise this code to other layouts we'll probably need to
98-
// get rid of all this logic and the *Stride auxiliary methods
99-
// and replace them by transposes and reshapes on the LinearLayout
100-
if (auto blockedEncoding =
101-
dyn_cast<triton::gpu::BlockedEncodingAttr>(legacyEncoding)) {
102-
order = llvm::to_vector(blockedEncoding.getOrder());
103-
} else {
104-
order = srcEncoding.getOrder();
105-
}
10694

10795
for (const auto &t : op.getInputTypes()) {
10896
if (t.getShape() != srcShape) {
10997
op.emitError() << "shape mismatch";
11098
}
111-
if (t.getEncoding() != legacyEncoding) {
99+
if (t.getEncoding() != srcEncoding) {
112100
op.emitError() << "encoding mismatch";
113101
}
114102
}
@@ -123,8 +111,12 @@ class ScanLoweringHelper {
123111
unsigned getNonAxisNumThreadsPerWarp();
124112
// Return the flat numbers of threads computing independent scan results.
125113
unsigned getNonAxisNumThreadsPerCTA();
114+
// Return the number of warps per CTA along axis dim.
115+
unsigned getAxisNumWarps();
126116
// Return the number of warps per CTA along axis dim with unique data.
127117
unsigned getAxisNumWarpsWithUniqueData();
118+
// Return the number of threads per warp along axis dim.
119+
unsigned getAxisNumThreadsPerWarp();
128120
// Return the number of threads per warp along axis dim with unique data.
129121
unsigned getAxisNumThreadsPerWarpWithUniqueData();
130122
// Return the number of blocks along axis dim.
@@ -147,20 +139,18 @@ class ScanLoweringHelper {
147139
Location getLoc() { return scanOp.getLoc(); }
148140
unsigned getAxis() { return scanOp.getAxis(); }
149141
bool getReverse() { return scanOp.getReverse(); }
150-
triton::gpu::LinearEncodingAttr getEncoding() { return srcEncoding; }
142+
triton::gpu::BlockedEncodingAttr getEncoding();
151143
llvm::ArrayRef<int64_t> getShape() { return srcShape; }
152144
unsigned getNumOperands() { return scanOp.getNumOperands(); }
153145
SmallVector<Type> getElementTypes() { return srcElementTypes; }
154-
SmallVector<unsigned> getOrder() { return order; }
146+
Attribute getSrcLayout() { return srcEncoding; }
155147
Region &getCombineOp();
156148

157149
private:
158150
triton::ScanOp scanOp;
159-
triton::gpu::LinearEncodingAttr srcEncoding;
160-
Attribute legacyEncoding;
151+
Attribute srcEncoding;
161152
llvm::ArrayRef<int64_t> srcShape;
162153
SmallVector<Type> srcElementTypes;
163-
SmallVector<unsigned> order;
164154
};
165155

166156
// Helper class for lowering `tt.gather` operations. This class shares lowering

lib/Analysis/Allocation.cpp

Lines changed: 9 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,6 @@ namespace triton {
2828

2929
// Bitwidth of pointers
3030
constexpr int kPtrBitWidth = 64;
31-
// Max shmem LDS/STS instruction in bits
32-
constexpr int kMaxShmemVecBitLength = 128;
3331

3432
static SmallVector<unsigned> getRepShapeForCvt(RankedTensorType srcTy,
3533
RankedTensorType dstTy) {
@@ -81,17 +79,15 @@ std::pair<unsigned, unsigned>
8179
getScratchCvtInOutVecLengths(RankedTensorType srcTy, RankedTensorType dstTy) {
8280
Attribute srcLayout = srcTy.getEncoding();
8381
Attribute dstLayout = dstTy.getEncoding();
84-
85-
auto srcLinAttr = gpu::toLinearEncoding(srcLayout, srcTy.getShape());
86-
auto dstLinAttr = gpu::toLinearEncoding(dstLayout, dstTy.getShape());
87-
auto inOrd = srcLinAttr.getOrder();
88-
auto outOrd = dstLinAttr.getOrder();
89-
82+
const auto &inOrd = gpu::getOrder(srcLayout);
83+
const auto &outOrd = gpu::getOrder(dstLayout);
9084
unsigned rank = srcTy.getRank();
9185

92-
unsigned srcContigPerThread = srcLinAttr.getContigPerThread()[inOrd[0]];
93-
unsigned dstContigPerThread = dstLinAttr.getContigPerThread()[outOrd[0]];
94-
// TODO: Fix the legacy issue that outOrd[0] == 0 always means
86+
unsigned srcContigPerThread =
87+
gpu::getUniqueContigPerThread(srcLayout, srcTy.getShape())[inOrd[0]];
88+
unsigned dstContigPerThread =
89+
gpu::getUniqueContigPerThread(dstLayout, dstTy.getShape())[outOrd[0]];
90+
// TODO: Fix the legacy issue that outOrd[0] == 0 always means
9591
// that we cannot do vectorization.
9692
unsigned innerDim = rank - 1;
9793
unsigned inVec = outOrd[0] != innerDim ? 1
@@ -121,7 +117,8 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
121117
Attribute dstLayout = dstTy.getEncoding();
122118

123119
assert(cvtNeedsSharedMemory(srcTy, dstTy));
124-
auto outOrd = gpu::toLinearEncoding(dstLayout, dstTy.getShape()).getOrder();
120+
121+
const auto &outOrd = gpu::getOrder(dstLayout);
125122
scratchConfig.order = outOrd;
126123

127124
std::tie(scratchConfig.inVec, scratchConfig.outVec) =
@@ -132,18 +129,6 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
132129
unsigned contiguousShapeDim = scratchConfig.repShape[scratchConfig.order[0]];
133130
scratchConfig.inVec = std::min(scratchConfig.inVec, contiguousShapeDim);
134131
scratchConfig.outVec = std::min(scratchConfig.outVec, contiguousShapeDim);
135-
// Clamp the vector length to kMaxShmemVecBitLength / element bitwidth as this
136-
// is the max vectorisation
137-
auto inBitWidth = isa<PointerType>(srcTy.getElementType())
138-
? kPtrBitWidth
139-
: srcTy.getElementTypeBitWidth();
140-
auto outBitWidth = isa<PointerType>(dstTy.getElementType())
141-
? kPtrBitWidth
142-
: dstTy.getElementTypeBitWidth();
143-
scratchConfig.inVec =
144-
std::min(scratchConfig.inVec, kMaxShmemVecBitLength / inBitWidth);
145-
scratchConfig.outVec =
146-
std::min(scratchConfig.outVec, kMaxShmemVecBitLength / outBitWidth);
147132

148133
// No padding is required if the tensor is 1-D, or if all dimensions except
149134
// the first accessed dimension have a size of 1.

lib/Analysis/AxisInfo.cpp

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1222,16 +1222,15 @@ unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) {
12221222
auto tensorTy = dyn_cast<RankedTensorType>(ptr.getType());
12231223
if (!tensorTy)
12241224
return 1;
1225+
auto layout = tensorTy.getEncoding();
12251226

1226-
// FIXME: This is not as good as it could be, as we don't need to restrict
1227-
// the analysis to one dimension. We should determine contiguity on the
1228-
// flattenOuts() layout
1229-
auto linAttr =
1230-
gpu::toLinearEncoding(tensorTy.getEncoding(), tensorTy.getShape());
1231-
auto order = linAttr.getOrder();
1227+
// Here order should be sorted most-contiguous first, so the first element
1228+
// should have the largest contiguity.
1229+
auto order = triton::gpu::getOrder(layout);
12321230
unsigned align = getPtrAlignment(ptr);
12331231

1234-
auto uniqueContigPerThread = linAttr.getContigPerThread();
1232+
auto uniqueContigPerThread =
1233+
triton::gpu::getUniqueContigPerThread(layout, tensorTy.getShape());
12351234
assert(order[0] < uniqueContigPerThread.size() &&
12361235
"Unexpected uniqueContigPerThread size");
12371236
unsigned contiguity = uniqueContigPerThread[order[0]];
@@ -1248,9 +1247,8 @@ unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) {
12481247
auto *axisInfo = getAxisInfo(ptr);
12491248
if (!axisInfo)
12501249
return 1;
1251-
auto linAttr =
1252-
gpu::toLinearEncoding(tensorTy.getEncoding(), tensorTy.getShape());
1253-
auto order = linAttr.getOrder();
1250+
auto layout = tensorTy.getEncoding();
1251+
auto order = triton::gpu::getOrder(layout);
12541252
auto maxMultipleBytes = axisInfo->getDivisibility(order[0]);
12551253
auto maxContig = axisInfo->getContiguity(order[0]);
12561254
auto elemNumBits = triton::getPointeeBitWidth(tensorTy);
@@ -1277,9 +1275,7 @@ unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) {
12771275
auto *axisInfo = getAxisInfo(mask);
12781276
if (!axisInfo)
12791277
return 1;
1280-
auto linAttr =
1281-
gpu::toLinearEncoding(tensorTy.getEncoding(), tensorTy.getShape());
1282-
auto maskOrder = linAttr.getOrder();
1278+
auto maskOrder = triton::gpu::getOrder(tensorTy.getEncoding());
12831279
auto alignment = std::max<unsigned>(axisInfo->getConstancy(maskOrder[0]), 1);
12841280
LDBG("getMaskAlignment maskOrder[0] " << maskOrder[0] << " alignment "
12851281
<< alignment);

lib/Analysis/Utility.cpp

Lines changed: 61 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -23,18 +23,37 @@
2323
#include "triton/Tools/Sys/GetEnv.hpp"
2424

2525
namespace mlir {
26+
namespace {
2627

2728
using namespace triton;
2829
using namespace triton::gpu;
2930

31+
int getParentAxis(Attribute layout, int axis) {
32+
if (auto sliceEncoding = dyn_cast<SliceEncodingAttr>(layout)) {
33+
axis = axis < sliceEncoding.getDim() ? axis : axis + 1;
34+
return getParentAxis(sliceEncoding.getParent(), axis);
35+
}
36+
return axis;
37+
}
38+
39+
SmallVector<unsigned> getParentOrder(Attribute layout) {
40+
if (auto sliceEncoding = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
41+
return getParentOrder(sliceEncoding.getParent());
42+
}
43+
return getThreadOrder(layout);
44+
}
45+
46+
} // namespace
47+
3048
// TODO(jlebar): Move this class into namespace triton.
3149
bool ReduceOpHelper::isReductionOnLayoutFastAxis() {
32-
auto linearEncoding = toLinearEncoding(getSrcLayout(), getSrcShape());
33-
return linearEncoding.getOrder()[0] == axis;
50+
return getParentAxis(getSrcLayout(), axis) ==
51+
getParentOrder(getSrcLayout())[0];
3452
}
3553

3654
SmallVector<unsigned> ReduceOpHelper::getOrderWithAxisAtBeginning() {
37-
auto order = toLinearEncoding(getSrcLayout(), getSrcShape()).getOrder();
55+
auto srcLayout = getSrcLayout();
56+
auto order = getOrder(srcLayout);
3857
auto it = std::find(order.begin(), order.end(), axis);
3958
// delete the axis from order
4059
order.erase(it);
@@ -206,59 +225,69 @@ bool ReduceOpHelper::isSupportedLayout() {
206225
}
207226

208227
unsigned ScanLoweringHelper::getAxisNumElementsPerThread() {
209-
return getEncoding().getContigPerThread()[getAxis()];
228+
return getEncoding().getSizePerThread()[getAxis()];
210229
}
211230

212231
unsigned ScanLoweringHelper::getNonAxisNumElementsPerThread() {
213-
auto contigPerThread = getEncoding().getContigPerThread();
214-
contigPerThread[getAxis()] = 1;
215-
return product<unsigned>(contigPerThread);
232+
SmallVector<unsigned> sizePerThreads = getContigPerThread(getEncoding());
233+
sizePerThreads[getAxis()] = 1;
234+
return product<unsigned>(sizePerThreads);
216235
}
217236

218237
Region &ScanLoweringHelper::getCombineOp() { return scanOp.getCombineOp(); }
219238

239+
unsigned ScanLoweringHelper::getAxisNumThreadsPerWarp() {
240+
return getThreadsPerWarp(getEncoding())[getAxis()];
241+
}
242+
220243
unsigned ScanLoweringHelper::getAxisNumThreadsPerWarpWithUniqueData() {
221-
return getEncoding().getThreadsPerWarp()[getAxis()];
244+
return getThreadsPerWarpWithUniqueData(getEncoding(), getShape())[getAxis()];
222245
}
223246

224247
unsigned ScanLoweringHelper::getNonAxisNumThreadsPerWarp() {
225-
auto nThreads = product(getEncoding().getThreadsPerWarp());
226-
return nThreads / getAxisNumThreadsPerWarpWithUniqueData();
248+
auto threadsPerWarp = getThreadsPerWarp(getEncoding());
249+
threadsPerWarp[getAxis()] = 1;
250+
return product<unsigned>(threadsPerWarp);
227251
}
228252

229253
// Return the flat numbers of threads computing independent scan results.
230254
unsigned ScanLoweringHelper::getNonAxisNumThreadsPerCTA() {
231-
auto nWarps = product(getEncoding().getWarpsPerCTA());
232-
return (nWarps / getAxisNumWarpsWithUniqueData()) *
233-
getNonAxisNumThreadsPerWarp();
255+
unsigned numParallelThreadsPerWarp = getNonAxisNumThreadsPerWarp();
256+
auto warpsPerCTA = getWarpsPerCTA(getEncoding());
257+
warpsPerCTA[getAxis()] = 1;
258+
unsigned numParallelWarpsPerCTA = product<unsigned>(warpsPerCTA);
259+
return numParallelThreadsPerWarp * numParallelWarpsPerCTA;
260+
}
261+
262+
unsigned ScanLoweringHelper::getAxisNumWarps() {
263+
return getWarpsPerCTA(getEncoding())[getAxis()];
234264
}
235265

236266
unsigned ScanLoweringHelper::getAxisNumWarpsWithUniqueData() {
237-
return getEncoding().getWarpsPerCTA()[getAxis()];
267+
return getWarpsPerCTAWithUniqueData(getEncoding(), getShape())[getAxis()];
238268
}
239269

240270
unsigned ScanLoweringHelper::getAxisNumBlocks() {
241-
auto contigPerThread = getEncoding().getContigPerThread();
271+
auto sizePerThreads = getSizePerThread(getEncoding());
242272
auto threadsPerWarp = getThreadsPerWarp(getEncoding());
243273
auto warpsPerCTA = getWarpsPerCTA(getEncoding());
244274
unsigned axis = getAxis();
245275
return ceil<unsigned>(
246276
getShape()[axis],
247-
(contigPerThread[axis] * threadsPerWarp[axis] * warpsPerCTA[axis]));
277+
(sizePerThreads[axis] * threadsPerWarp[axis] * warpsPerCTA[axis]));
248278
}
249279

250280
unsigned ScanLoweringHelper::getNonAxisNumBlocks() {
251-
auto contigPerThread = getEncoding().getContigPerThread();
281+
auto sizePerThreads = getSizePerThread(getEncoding());
252282
auto threadsPerWarp = getThreadsPerWarp(getEncoding());
253283
auto warpsPerCTA = getWarpsPerCTA(getEncoding());
254-
auto rank = contigPerThread.size();
255284
unsigned axis = getAxis();
256285
unsigned numBlocks = 1;
257-
for (unsigned i = 0; i < rank; i++) {
286+
for (unsigned i = 0; i < sizePerThreads.size(); i++) {
258287
if (i == axis)
259288
continue;
260289
numBlocks *=
261-
ceil<unsigned>(getShape()[i], (contigPerThread[i] * threadsPerWarp[i] *
290+
ceil<unsigned>(getShape()[i], (sizePerThreads[i] * threadsPerWarp[i] *
262291
warpsPerCTA[i]));
263292
}
264293
return numBlocks;
@@ -267,7 +296,7 @@ unsigned ScanLoweringHelper::getNonAxisNumBlocks() {
267296
bool ScanLoweringHelper::isSupported() {
268297
// TODO: Support the following cases:
269298
// 1. Scan on non-blocked encodings
270-
if (!isa<BlockedEncodingAttr>(legacyEncoding))
299+
if (!isa<BlockedEncodingAttr>(srcEncoding))
271300
return false;
272301
return true;
273302
}
@@ -555,43 +584,42 @@ getReshapeDecomposition(ArrayRef<int64_t> srcShape,
555584
return ret;
556585
}
557586

587+
BlockedEncodingAttr ScanLoweringHelper::getEncoding() {
588+
return cast<BlockedEncodingAttr>(srcEncoding);
589+
}
590+
558591
unsigned ScanLoweringHelper::getAxisElementStride() {
559-
auto order = getOrder();
592+
auto order = getOrder(getEncoding());
560593
unsigned stride = 1;
561594
for (unsigned dim : order) {
562595
if (dim == getAxis())
563596
return stride;
564-
stride *= getEncoding().getContigPerThread()[dim];
597+
stride *= getContigPerThread(getEncoding())[dim];
565598
}
566599
llvm_unreachable("Axis not found in order");
567600
}
568601

569602
unsigned ScanLoweringHelper::getAxisThreadStride() {
570-
auto encoding = getEncoding();
571-
auto kThread = StringAttr::get(encoding.getContext(), "lane");
572-
// OOOGHHH This is nasty. We should implement this lowering via LLs natively
573-
// to avoid this
574-
auto threadsPerWarp = encoding.basesPerDim(kThread, /*skipBroadcast=*/false);
575-
auto order = getOrder();
603+
auto order = getOrder(getEncoding());
576604
unsigned stride = 1;
577605
for (unsigned dim : order) {
578606
if (dim == getAxis())
579607
return stride;
580-
stride *= threadsPerWarp[dim];
608+
stride *= getEncoding().getThreadsPerWarp()[dim];
581609
}
582610
llvm_unreachable("Axis not found in order");
583611
}
584612

585613
unsigned ScanLoweringHelper::getAxisBlockStride() {
586-
auto order = getOrder();
614+
auto order = getOrder(getEncoding());
587615
unsigned stride = 1;
588-
auto contigPerThread = getEncoding().getContigPerThread();
616+
auto sizePerThreads = getSizePerThread(getEncoding());
589617
auto threadsPerWarp = getThreadsPerWarp(getEncoding());
590618
auto warpsPerCTA = getWarpsPerCTA(getEncoding());
591619
for (unsigned dim : order) {
592620
if (dim == getAxis())
593621
return stride;
594-
stride *= ceil<unsigned int>(getShape()[dim], contigPerThread[dim] *
622+
stride *= ceil<unsigned int>(getShape()[dim], sizePerThreads[dim] *
595623
threadsPerWarp[dim] *
596624
warpsPerCTA[dim]);
597625
}

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ using ::mlir::LLVM::linearize;
99
using ::mlir::triton::gpu::DotOperandEncodingAttr;
1010
using ::mlir::triton::gpu::expandMatrixOrderWithBatch;
1111
using ::mlir::triton::gpu::expandMatrixShapeWithBatch;
12+
using ::mlir::triton::gpu::getContigPerThread;
1213
using ::mlir::triton::gpu::getOrder;
1314
using ::mlir::triton::gpu::getShapePerCTA;
1415
using ::mlir::triton::gpu::getSizePerThread;

0 commit comments

Comments
 (0)