
Commit 7573b11

Merge commit '38a11b859fff79ea214256d3f1cfe43d54e36c2c'
2 parents: d7095ce + 38a11b8

File tree: 18 files changed, +82 -75 lines


cmake/llvm-hash.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-b74e588e1f460eb48ceb1a30cf8ac870b7537dcc
+fa57c7a6a5f594a9e3ae2dbe3542cf89a20cdd73

include/triton/Analysis/Utility.h

Lines changed: 1 addition & 1 deletion
@@ -66,7 +66,7 @@ class ReduceOpHelper {
   // The shape of the shared memory space needed for the reduction.
   SmallVector<unsigned> getScratchRepShape();
 
-  SmallVector<unsigned> getOrderWithAxisAtBeginning();
+  SmallVector<unsigned> getThreadOrderWithAxisAtBeginning();
 
   unsigned getScratchSizeInBytes();
 
lib/Analysis/Allocation.cpp

Lines changed: 3 additions & 0 deletions
@@ -46,6 +46,9 @@ static SmallVector<unsigned> getRepShapeForCvt(RankedTensorType srcTy,
   auto dstShapePerCTATile =
       gpu::getShapePerCTATile(dstLayout, dstTy.getShape());
 
+  assert(srcTy.getRank() == dstTy.getRank() &&
+         "src and dst must have the same rank");
+
   unsigned rank = dstTy.getRank();
   SmallVector<unsigned> repShape(rank);
   for (unsigned d = 0; d < rank; ++d) {
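The new assert guards the loop that follows, which indexes the per-CTA-tile shapes of both layouts with the same dimension index. A minimal sketch of the pattern being protected; the max over the two tile shapes is an assumption about the loop body, which the hunk truncates:

// Sketch only: with unequal ranks, indexing both tile shapes by the same
// `d` would run past the end of the shorter vector.
unsigned rank = dstTy.getRank();
SmallVector<unsigned> repShape(rank);
for (unsigned d = 0; d < rank; ++d)
  repShape[d] = std::max<unsigned>(srcShapePerCTATile[d], // needs src rank == rank
                                   dstShapePerCTATile[d]);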

lib/Analysis/AxisInfo.cpp

Lines changed: 3 additions & 3 deletions
@@ -1213,7 +1213,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) {
 
   // Here order should be ordered by contiguous first, so the first element
   // should have the largest contiguous.
-  auto order = triton::gpu::getOrder(layout);
+  auto order = triton::gpu::getThreadOrder(layout);
   unsigned align = getPtrAlignment(ptr);
 
   auto uniqueContigPerThread =
@@ -1235,7 +1235,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) {
   if (!axisInfo)
     return 1;
   auto layout = tensorTy.getEncoding();
-  auto order = triton::gpu::getOrder(layout);
+  auto order = triton::gpu::getThreadOrder(layout);
   auto maxMultipleBytes = axisInfo->getDivisibility(order[0]);
   auto maxContig = axisInfo->getContiguity(order[0]);
   auto elemNumBits = triton::getPointeeBitWidth(tensorTy);
@@ -1262,7 +1262,7 @@ unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) {
   auto *axisInfo = getAxisInfo(mask);
   if (!axisInfo)
     return 1;
-  auto maskOrder = triton::gpu::getOrder(tensorTy.getEncoding());
+  auto maskOrder = triton::gpu::getThreadOrder(tensorTy.getEncoding());
   auto alignment = std::max<unsigned>(axisInfo->getConstancy(maskOrder[0]), 1);
   LDBG("getMaskAlignment maskOrder[0] " << maskOrder[0] << " alignment "
        << alignment);
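All three call sites consult only order[0], the fastest-varying dimension within a thread, which is why the per-thread order is the right query here. A standalone illustration with assumed values:

#include <llvm/ADT/SmallVector.h>
using llvm::SmallVector;

// Illustration with assumed values: threadOrder[0] names the dimension a
// thread walks contiguously, so contiguity/constancy are read along it.
unsigned alignmentAlongFastestDim(const SmallVector<unsigned> &threadOrder,
                                  const SmallVector<unsigned> &contiguity) {
  return contiguity[threadOrder[0]];
}
// e.g. threadOrder = {1, 0}, contiguity = {1, 8}  ->  8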

lib/Analysis/Utility.cpp

Lines changed: 5 additions & 5 deletions
@@ -34,9 +34,9 @@ int getParentAxis(Attribute layout, int axis) {
   return axis;
 }
 
-SmallVector<unsigned> getParentOrder(Attribute layout) {
+SmallVector<unsigned> getParentThreadOrder(Attribute layout) {
   if (auto sliceEncoding = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
-    return getParentOrder(sliceEncoding.getParent());
+    return getParentThreadOrder(sliceEncoding.getParent());
   }
   return getThreadOrder(layout);
 }
@@ -46,12 +46,12 @@ SmallVector<unsigned> getParentOrder(Attribute layout) {
 // TODO(jlebar): Move this class into namespace triton.
 bool ReduceOpHelper::isReductionOnLayoutFastAxis() {
   return getParentAxis(getSrcLayout(), axis) ==
-         getParentOrder(getSrcLayout())[0];
+         getParentThreadOrder(getSrcLayout())[0];
 }
 
-SmallVector<unsigned> ReduceOpHelper::getOrderWithAxisAtBeginning() {
+SmallVector<unsigned> ReduceOpHelper::getThreadOrderWithAxisAtBeginning() {
   auto srcLayout = getSrcLayout();
-  auto order = getOrder(srcLayout);
+  auto order = getThreadOrder(srcLayout);
   auto it = std::find(order.begin(), order.end(), axis);
   // delete the axis from order
   order.erase(it);
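The renamed helper removes the reduction axis from the thread order and reinserts it at the front; the re-insert sits just past the end of this hunk. A self-contained sketch of that permutation with assumed example values:

#include <algorithm>
#include <vector>

// Sketch of the move-axis-to-front step; the real code operates on a
// SmallVector<unsigned> and continues past the hunk with the re-insert.
std::vector<unsigned> withAxisAtBeginning(std::vector<unsigned> order,
                                          unsigned axis) {
  order.erase(std::find(order.begin(), order.end(), axis));
  order.insert(order.begin(), axis);
  return order; // {1, 0, 2} with axis = 2  ->  {2, 1, 0}
}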

lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp

Lines changed: 2 additions & 2 deletions
@@ -283,7 +283,7 @@ struct ReduceOpConversion
         getMultiDimWarpId(helper, warpId, loc, rewriter);
     Value warpIdAxis = multiDimWarpId[axis];
 
-    auto smemOrder = helper.getOrderWithAxisAtBeginning();
+    auto smemOrder = helper.getThreadOrderWithAxisAtBeginning();
     for (auto it : accs) {
       const SmallVector<unsigned> &key = it.first;
       SmallVector<Value> &acc = it.second;
@@ -370,7 +370,7 @@ struct ReduceOpConversion
     Location loc = op.getLoc();
     auto srcLayout = helper.getSrcLayout();
     auto axis = op.getAxis();
-    auto smemOrder = helper.getOrderWithAxisAtBeginning();
+    auto smemOrder = helper.getThreadOrderWithAxisAtBeginning();
     SmallVector<Value> results(op.getNumOperands());
     for (unsigned i = 0; i < op.getNumOperands(); ++i) {
      auto elemTy = getElementType(op, i);

lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp

Lines changed: 3 additions & 4 deletions
@@ -181,9 +181,9 @@ struct SplitOpConversion : public ConvertOpToLLVMPattern<SplitOp> {
     int numContiguousValues = 1;
     auto encoding = cast<BlockedEncodingAttr>(
         cast<RankedTensorType>(op.getSrc().getType()).getEncoding());
-    int splitDim = encoding.getOrder().size() - 1;
-    for (int i = 0; i < encoding.getOrder().size(); i++) {
-      if (encoding.getOrder()[i] == splitDim)
+    int splitDim = encoding.getThreadOrder().size() - 1;
+    for (int i = 0; i < encoding.getThreadOrder().size(); i++) {
+      if (encoding.getThreadOrder()[i] == splitDim)
         break;
       numContiguousValues *= encoding.getSizePerThread()[i];
     }
@@ -336,7 +336,6 @@ struct BroadcastOpConversion
     unsigned rank = srcTy.getRank();
     auto typeConverter = getTypeConverter();
     assert(rank == resultTy.getRank());
-    auto order = triton::gpu::getOrder(srcLayout);
     auto srcOffsets = emitOffsetForLayout(srcLayout, srcTy);
     auto resultOffsets = emitOffsetForLayout(resultLayout, resultTy);
     SmallVector<Value> srcVals = unpackLLElements(loc, src, rewriter);
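In the SplitOp hunk, dimensions that come before splitDim in the thread order vary faster per thread, so their sizePerThread entries multiply into the run of contiguous values. A worked example with assumed layout parameters:

// Assumed example values: threadOrder = {0, 1, 2}, sizePerThread = {4, 2, 2},
// splitDim = 2 (the last dimension), mirroring the SplitOp hunk above.
unsigned threadOrder[] = {0, 1, 2};
unsigned sizePerThread[] = {4, 2, 2};
unsigned splitDim = 2;
int numContiguousValues = 1;
for (unsigned i = 0; i < 3; i++) {
  if (threadOrder[i] == splitDim)
    break; // dims before splitDim in the order are thread-contiguous
  numContiguousValues *= sizePerThread[i]; // 4 after i=0, 8 after i=1
}
// numContiguousValues == 8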

lib/Dialect/Triton/Transforms/LoopUnroll.cpp

Lines changed: 14 additions & 6 deletions
@@ -22,21 +22,22 @@
 
 namespace mlir::triton {
 
-static const char *loopUnrollFactorAttrName = "tt.loop_unroll_factor";
-
 namespace {
 
 class LoopUnrollPass : public TritonLoopUnrollBase<LoopUnrollPass> {
 
   int getUnrollFactorOrDefault(scf::ForOp forOp) {
     // Use the attribute attached to the loop if it exists otherwise set the
     // factor to 1 to suppress the unrolling.
-    if (auto factor = forOp->getAttrOfType<IntegerAttr>(
-            mlir::triton::loopUnrollFactorAttrName))
+    if (auto factor =
+            forOp->getAttrOfType<IntegerAttr>(loopUnrollFactorAttrName))
       return factor.getInt();
     return 1;
   }
 
+  const char *loopUnrollFactorAttrName = "tt.loop_unroll_factor";
+  const char *pipelineStagesAttrName = "tt.num_stages";
+
 public:
   LoopUnrollPass() = default;
   LoopUnrollPass(const LoopUnrollPass &) {}
@@ -49,11 +50,18 @@ class LoopUnrollPass : public TritonLoopUnrollBase<LoopUnrollPass> {
       loops.push_back(forOp);
     });
 
+    auto ctx = getOperation()->getContext();
     for (auto loop : loops) {
       auto unrollFactor = getUnrollFactorOrDefault(loop);
-      loop->removeAttr(mlir::triton::loopUnrollFactorAttrName);
+      loop->removeAttr(loopUnrollFactorAttrName);
       LDBG("Unrolling loop by " << unrollFactor << " times\n" << loop);
-      (void)loopUnrollByFactor(loop, unrollFactor);
+      auto resultLoops = loopUnrollByFactor(loop, unrollFactor);
+      // Do not pipeline the epilog loop.
+      if (succeeded(resultLoops) && resultLoops->epilogueLoopOp) {
+        (*resultLoops->epilogueLoopOp)
+            ->setAttr(pipelineStagesAttrName,
+                      mlir::IntegerAttr::get(IntegerType::get(ctx, 32), 1));
+      }
     }
   }
 };
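The pass now tags the remainder ("epilog") loop produced by loopUnrollByFactor with tt.num_stages = 1 so the later pipeliner leaves it alone. For context, a hedged sketch of how a loop would be marked for unrolling in the first place; forOp and builder are assumed to exist in the caller, and only the attribute name comes from the pass:

// Sketch: request a 4x unroll of an scf::ForOp; this pass reads and then
// removes the attribute. getI32IntegerAttr is the standard MLIR builder call.
forOp->setAttr("tt.loop_unroll_factor", builder.getI32IntegerAttr(4));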

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 33 additions & 38 deletions
@@ -220,7 +220,7 @@ bool isExpensiveView(Type srcType, Type dstType) {
   return getTotalElemsPerThread(srcType) != getTotalElemsPerThread(dstType);
 }
 
-/* Utility function used by getOrder and getCTAOrder of SliceEncodingAttr.
+/* Utility function used by get.*Order methods of SliceEncodingAttr.
  * Erase dim and decrease all values larger than dim by 1.
  * Example:    order = [0, 2, 4, 3, 1], dim = 2
  *          resOrder = [0,    3, 2, 1]
@@ -265,29 +265,11 @@ SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
 }
 
 SmallVector<unsigned> getWarpOrder(Attribute layout) {
-  if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
-    if (isa<AMDMfmaEncodingAttr>(dotLayout.getParent())) {
-      return getWarpOrder(dotLayout.getParent());
-    }
-  }
-  auto order = getOrder(layout);
-  // FIXME: At the moment, warpOrder in Ampere is N-major but in Hopper it's
-  // M-major. This is awkward. Since we can choose any warpOrder in Ampere, we
-  // should probably choose M-major and change `LinearLayoutConversion.cpp` and
-  // `MMAv2.cpp` to match.
-  if (auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(layout)) {
-    if (mmaLayout.isHopper()) {
-      // Hopper MMA instructions force warps to be column-major
-      // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-wgmma-mma-async-m64nnk8
-      return getMatrixOrder(order.size(), /*rowMajor*/ false);
-    }
-  } else if (auto dotOpLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
-    // It's quite weird to talk about warp order when the warps
-    // are broadcasted along the K dimension
-    llvm::report_fatal_error(
-        "DotOperandEncoding::getWarpOrder not implemented");
-  }
-  return order;
+  if (auto distributedLayout = mlir::dyn_cast<DistributedEncodingTrait>(layout))
+    return distributedLayout.getWarpOrder();
+  else
+    llvm::report_fatal_error("Unimplemented usage of getWarpOrder");
+  return {};
 }
 
 SmallVector<unsigned> getOrder(Attribute layout) {
SmallVector<unsigned> getOrder(Attribute layout) {
@@ -296,7 +278,7 @@ SmallVector<unsigned> getOrder(Attribute layout) {
296278
}
297279
if (auto mmaLayout = dyn_cast<MmaEncodingTrait>(layout)) {
298280
// Order doesn't really matter. We just have to be consistent when unpacking
299-
// the elements in the MMAv2/V3 lowerings. We choose row-major
281+
// the output elements in the LLVM lowerings. We choose row-major
300282
auto distributedLayout = cast<DistributedEncodingTrait>(layout);
301283
auto rank = distributedLayout.getWarpsPerCTA().size();
302284
return getMatrixOrder(rank, /*rowMajor*/ true);
@@ -331,15 +313,15 @@ SmallVector<unsigned> getOrder(Attribute layout) {
 
   llvm::report_fatal_error("Unimplemented usage of getOrder");
   return {};
-};
+}
 
 SmallVector<unsigned> getThreadOrder(Attribute layout) {
   if (auto distributedLayout = mlir::dyn_cast<DistributedEncodingTrait>(layout))
     return distributedLayout.getThreadOrder();
   else
     llvm::report_fatal_error("Unimplemented usage of getThreadOrder");
   return {};
-};
+}
 
 CTALayoutAttr getCTALayout(Attribute layout) {
   if (auto distributedLayout =
@@ -782,7 +764,8 @@ SmallVector<unsigned> SliceEncodingAttr::getWarpsPerCTA() const {
   return warpsPerCTA;
 }
 SmallVector<unsigned> SliceEncodingAttr::getWarpOrder() const {
-  return ::getWarpOrder(*this);
+  auto parentWarpOrder = ::getWarpOrder(getParent());
+  return eraseOrder(parentWarpOrder, getDim());
 }
 SmallVector<unsigned> SliceEncodingAttr::getThreadsPerWarp() const {
   auto parent = getParent();
@@ -794,7 +777,8 @@ SmallVector<unsigned> SliceEncodingAttr::getThreadsPerWarp() const {
   return threadsPerWarp;
 }
 SmallVector<unsigned> SliceEncodingAttr::getThreadOrder() const {
-  return ::getOrder(*this);
+  auto parentThreadOrder = ::getThreadOrder(getParent());
+  return eraseOrder(parentThreadOrder, getDim());
 }
 SmallVector<unsigned> SliceEncodingAttr::getSizePerThread() const {
   auto sizePerThread = ::getSizePerThread(getParent());
@@ -1065,7 +1049,14 @@ SmallVector<unsigned> DotOperandEncodingAttr::getWarpsPerCTA() const {
   return warps;
 }
 SmallVector<unsigned> DotOperandEncodingAttr::getWarpOrder() const {
-  return ::getWarpOrder(*this);
+  // FIXME(Lezcano): Preexisting. Do we want to have this path at all?
+  if (mlir::isa<AMDMfmaEncodingAttr>(getParent())) {
+    return ::getWarpOrder(getParent());
+  }
+  // It's quite weird to talk about warp order when the warps
+  // are broadcasted along the K dimension
+  llvm::report_fatal_error("DotOperandEncoding::getWarpOrder not implemented");
+  return {};
 }
 SmallVector<unsigned> DotOperandEncodingAttr::getThreadOrder() const {
   // FIXME: delete if branch for `DpasEncodingAttr` and provide more
@@ -1637,7 +1628,7 @@ SmallVector<unsigned> AMDMfmaEncodingAttr::getWarpsPerCTA() const {
   return SmallVector<unsigned>(getWarpsPerCTA__());
 }
 SmallVector<unsigned> AMDMfmaEncodingAttr::getWarpOrder() const {
-  return ::getWarpOrder(*this);
+  return ::getOrder(*this);
 }
 SmallVector<unsigned> AMDMfmaEncodingAttr::getThreadOrder() const {
   auto order = ::getOrder(*this);
@@ -1806,7 +1797,7 @@ SmallVector<unsigned> AMDWmmaEncodingAttr::getWarpsPerCTA() const {
   return SmallVector<unsigned>(getWarpsPerCTA__());
 }
 SmallVector<unsigned> AMDWmmaEncodingAttr::getWarpOrder() const {
-  return ::getWarpOrder(*this);
+  return ::getOrder(*this);
 }
 SmallVector<unsigned> AMDWmmaEncodingAttr::getThreadOrder() const {
   return ::getOrder(*this);
@@ -1930,7 +1921,11 @@ SmallVector<unsigned> NvidiaMmaEncodingAttr::getWarpsPerCTA() const {
   return SmallVector<unsigned>(getWarpsPerCTA__());
 }
 SmallVector<unsigned> NvidiaMmaEncodingAttr::getWarpOrder() const {
-  return ::getWarpOrder(*this);
+  auto rank = getWarpsPerCTA().size();
+  // Hopper (wgmma) uses column-major as this is embedded in the instruction.
+  // For Ampere we can choose either row-major or column-major.
+  // We choose row-major as the legacy path did so.
+  return getMatrixOrder(rank, /*rowMajor*/ !isHopper());
 }
 SmallVector<unsigned> NvidiaMmaEncodingAttr::getThreadsPerWarp() const {
   auto rank = getWarpsPerCTA().size();
@@ -1954,10 +1949,11 @@ SmallVector<unsigned> NvidiaMmaEncodingAttr::getThreadsPerWarp() const {
       "getThreadsPerWarp not implemented for unknown Mma version ");
 }
 SmallVector<unsigned> NvidiaMmaEncodingAttr::getThreadOrder() const {
-  return ::getOrder(*this);
+  auto rank = getWarpsPerCTA().size();
+  return getMatrixOrder(rank, /*rowMajor*/ true);
 }
 SmallVector<unsigned> NvidiaMmaEncodingAttr::getSizePerThread() const {
-  auto rank = ::getOrder(*this).size();
+  auto rank = getWarpsPerCTA().size();
   SmallVector<unsigned> res(rank, 1);
   if (isAmpere()) {
     res[rank - 2] = 2;
@@ -2198,11 +2194,10 @@ NvidiaMmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
   if (opIdx == 0) {
     sizePerThread[rank - 2] = 2;
     sizePerThread[rank - 1] = 2 * kWidth;
-  } else if (opIdx == 1) {
+  } else {
+    assert(opIdx == 1);
     sizePerThread[rank - 2] = 2 * kWidth;
     sizePerThread[rank - 1] = 1;
-  } else {
-    llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
   }
   return sizePerThread;
 }
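Two small helpers carry most of this diff: eraseOrder projects a parent order onto a slice (as the updated comment describes), and getMatrixOrder produces a row- or column-major permutation. Hedged re-implementations matching how the hunks use them; the names and exact semantics are inferred from the call sites, not quoted from the commit:

#include <llvm/ADT/ArrayRef.h>
#include <llvm/ADT/SmallVector.h>
#include <numeric>
using llvm::ArrayRef;
using llvm::SmallVector;

// eraseOrder, per the comment above: drop `dim` and shift larger dims down.
// [0, 2, 4, 3, 1], dim = 2  ->  [0, 3, 2, 1]
SmallVector<unsigned> eraseOrderSketch(ArrayRef<unsigned> order, unsigned dim) {
  SmallVector<unsigned> res;
  for (unsigned d : order)
    if (d != dim)
      res.push_back(d > dim ? d - 1 : d);
  return res;
}

// getMatrixOrder as the hunks use it (assumed semantics): order[0] is the
// fastest-varying dim, rank-1 for row-major, rank-2 for column-major.
SmallVector<unsigned> getMatrixOrderSketch(unsigned rank, bool rowMajor) {
  SmallVector<unsigned> order(rank);
  std::iota(order.rbegin(), order.rend(), 0u); // [rank-1, ..., 1, 0]
  if (!rowMajor)
    std::swap(order[0], order[1]); // column-major: rank-2 varies fastest
  return order;
}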

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 1 addition & 0 deletions
@@ -328,6 +328,7 @@ LinearLayout hopperMmaToLinearLayout(ArrayRef<int64_t> shape,
   assert(n == 8 || n == 16 || n == 32 || n == 64 || n == 128 || n == 256);
   assert(k == 8 || k == 16 || k == 32);
 
+  // TODO Make the getOrder of Hopper explicit here via an assert
   MLIRContext *ctx = mma.getContext();
   LinearLayout ctaLayout(
       {{S("register"), {{1, 0}, {0, 8}}},
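One way the new TODO could be discharged, sketched under the assumption that getOrder(mma) must agree with the row-major convention chosen in Dialect.cpp above; this is illustrative, not part of the commit:

// Hypothetical assert for the TODO; relies on SmallVector's operator== and
// the getMatrixOrder convention sketched earlier (rank 2 assumed here).
assert(getOrder(mma) == getMatrixOrder(/*rank=*/2, /*rowMajor=*/true) &&
       "hopperMmaToLinearLayout assumes a row-major order");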
