Skip to content

Commit 38a11b8

Browse files
authored
[BACKEND] Fix uses of getOrder(DotOperand(Nvidia) and MMA(Nvidia)) (#5055)
We use `getOrder` very liberally throughout the codebase, when we really meant to use `getThreadOrder`. This is an issue with the input layout is an `DotOperand(mma(opIdx=1))`, where the thread order and the matrix order are opposite. Found this to be an issue when a PR changed the `getOrder` of `DotOperand(Hopper)` to an incorrect one and CI still passed! The issue here is that the LLVM lowering for wgmma and the LinearLayout does not use `getOrder`, but there are many other subsystems do, and many heuristics would be getting an incorrect order, and potentially be disabled. This is particularly problematic for `DotOperand(opIdx=1)` in nvidia hardware, as `getThreadOrder` and `getOrder` are different! While doing so we: - Audit most (all?) the calls to `getOrder(dotOperand)`. It turns out that most of them really meant `getThreadOrder` - Fix the ordering methods of `SliceEncodingAttr` to be consistent - Move the implementation of `getWarpOrder` to the Attr classes, because of OOP The test strategy was to add `llvm::report_fatal_error("Testing");` within `getOrder(nvidiaMma)` and `getOrder(DotOperand(nvidiaMma))` and triaging all errors that were raised in CI.
1 parent d2b8659 commit 38a11b8

File tree

13 files changed

+63
-65
lines changed

13 files changed

+63
-65
lines changed

include/triton/Analysis/Utility.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ class ReduceOpHelper {
6666
// The shape of the shared memory space needed for the reduction.
6767
SmallVector<unsigned> getScratchRepShape();
6868

69-
SmallVector<unsigned> getOrderWithAxisAtBeginning();
69+
SmallVector<unsigned> getThreadOrderWithAxisAtBeginning();
7070

7171
unsigned getScratchSizeInBytes();
7272

lib/Analysis/Allocation.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ static SmallVector<unsigned> getRepShapeForCvt(RankedTensorType srcTy,
4646
auto dstShapePerCTATile =
4747
gpu::getShapePerCTATile(dstLayout, dstTy.getShape());
4848

49+
assert(srcTy.getRank() == dstTy.getRank() &&
50+
"src and dst must have the same rank");
51+
4952
unsigned rank = dstTy.getRank();
5053
SmallVector<unsigned> repShape(rank);
5154
for (unsigned d = 0; d < rank; ++d) {

lib/Analysis/AxisInfo.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1213,7 +1213,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) {
12131213

12141214
// Here order should be ordered by contiguous first, so the first element
12151215
// should have the largest contiguous.
1216-
auto order = triton::gpu::getOrder(layout);
1216+
auto order = triton::gpu::getThreadOrder(layout);
12171217
unsigned align = getPtrAlignment(ptr);
12181218

12191219
auto uniqueContigPerThread =
@@ -1235,7 +1235,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) {
12351235
if (!axisInfo)
12361236
return 1;
12371237
auto layout = tensorTy.getEncoding();
1238-
auto order = triton::gpu::getOrder(layout);
1238+
auto order = triton::gpu::getThreadOrder(layout);
12391239
auto maxMultipleBytes = axisInfo->getDivisibility(order[0]);
12401240
auto maxContig = axisInfo->getContiguity(order[0]);
12411241
auto elemNumBits = triton::getPointeeBitWidth(tensorTy);
@@ -1262,7 +1262,7 @@ unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) {
12621262
auto *axisInfo = getAxisInfo(mask);
12631263
if (!axisInfo)
12641264
return 1;
1265-
auto maskOrder = triton::gpu::getOrder(tensorTy.getEncoding());
1265+
auto maskOrder = triton::gpu::getThreadOrder(tensorTy.getEncoding());
12661266
auto alignment = std::max<unsigned>(axisInfo->getConstancy(maskOrder[0]), 1);
12671267
LDBG("getMaskAlignment maskOrder[0] " << maskOrder[0] << " alignment "
12681268
<< alignment);

lib/Analysis/Utility.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@ int getParentAxis(Attribute layout, int axis) {
3232
return axis;
3333
}
3434

35-
SmallVector<unsigned> getParentOrder(Attribute layout) {
35+
SmallVector<unsigned> getParentThreadOrder(Attribute layout) {
3636
if (auto sliceEncoding = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
37-
return getParentOrder(sliceEncoding.getParent());
37+
return getParentThreadOrder(sliceEncoding.getParent());
3838
}
3939
return getThreadOrder(layout);
4040
}
@@ -44,12 +44,12 @@ SmallVector<unsigned> getParentOrder(Attribute layout) {
4444
// TODO(jlebar): Move this class into namespace triton.
4545
bool ReduceOpHelper::isReductionOnLayoutFastAxis() {
4646
return getParentAxis(getSrcLayout(), axis) ==
47-
getParentOrder(getSrcLayout())[0];
47+
getParentThreadOrder(getSrcLayout())[0];
4848
}
4949

50-
SmallVector<unsigned> ReduceOpHelper::getOrderWithAxisAtBeginning() {
50+
SmallVector<unsigned> ReduceOpHelper::getThreadOrderWithAxisAtBeginning() {
5151
auto srcLayout = getSrcLayout();
52-
auto order = getOrder(srcLayout);
52+
auto order = getThreadOrder(srcLayout);
5353
auto it = std::find(order.begin(), order.end(), axis);
5454
// delete the axis from order
5555
order.erase(it);

lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ struct ReduceOpConversion
283283
getMultiDimWarpId(helper, warpId, loc, rewriter);
284284
Value warpIdAxis = multiDimWarpId[axis];
285285

286-
auto smemOrder = helper.getOrderWithAxisAtBeginning();
286+
auto smemOrder = helper.getThreadOrderWithAxisAtBeginning();
287287
for (auto it : accs) {
288288
const SmallVector<unsigned> &key = it.first;
289289
SmallVector<Value> &acc = it.second;
@@ -370,7 +370,7 @@ struct ReduceOpConversion
370370
Location loc = op.getLoc();
371371
auto srcLayout = helper.getSrcLayout();
372372
auto axis = op.getAxis();
373-
auto smemOrder = helper.getOrderWithAxisAtBeginning();
373+
auto smemOrder = helper.getThreadOrderWithAxisAtBeginning();
374374
SmallVector<Value> results(op.getNumOperands());
375375
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
376376
auto elemTy = getElementType(op, i);

lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -181,9 +181,9 @@ struct SplitOpConversion : public ConvertOpToLLVMPattern<SplitOp> {
181181
int numContiguousValues = 1;
182182
auto encoding = cast<BlockedEncodingAttr>(
183183
cast<RankedTensorType>(op.getSrc().getType()).getEncoding());
184-
int splitDim = encoding.getOrder().size() - 1;
185-
for (int i = 0; i < encoding.getOrder().size(); i++) {
186-
if (encoding.getOrder()[i] == splitDim)
184+
int splitDim = encoding.getThreadOrder().size() - 1;
185+
for (int i = 0; i < encoding.getThreadOrder().size(); i++) {
186+
if (encoding.getThreadOrder()[i] == splitDim)
187187
break;
188188
numContiguousValues *= encoding.getSizePerThread()[i];
189189
}
@@ -336,7 +336,6 @@ struct BroadcastOpConversion
336336
unsigned rank = srcTy.getRank();
337337
auto typeConverter = getTypeConverter();
338338
assert(rank == resultTy.getRank());
339-
auto order = triton::gpu::getOrder(srcLayout);
340339
auto srcOffsets = emitOffsetForLayout(srcLayout, srcTy);
341340
auto resultOffsets = emitOffsetForLayout(resultLayout, resultTy);
342341
SmallVector<Value> srcVals = unpackLLElements(loc, src, rewriter);

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 33 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ bool isExpensiveView(Type srcType, Type dstType) {
217217
return getTotalElemsPerThread(srcType) != getTotalElemsPerThread(dstType);
218218
}
219219

220-
/* Utility function used by getOrder and getCTAOrder of SliceEncodingAttr.
220+
/* Utility function used by get.*Order methods of SliceEncodingAttr.
221221
* Erase dim and decrease all values larger than dim by 1.
222222
* Example: order = [0, 2, 4, 3, 1], dim = 2
223223
* resOrder = [0, 3, 2, 1]
@@ -262,29 +262,11 @@ SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
262262
}
263263

264264
SmallVector<unsigned> getWarpOrder(Attribute layout) {
265-
if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
266-
if (isa<AMDMfmaEncodingAttr>(dotLayout.getParent())) {
267-
return getWarpOrder(dotLayout.getParent());
268-
}
269-
}
270-
auto order = getOrder(layout);
271-
// FIXME: At the moment, warpOrder in Ampere is N-major but in Hopper it's
272-
// M-major This is awkward. Since we can choose any warpOrder in Ampere, we
273-
// should probably choose M-major and change `LinearLayoutConversion.cpp` and
274-
// `MMAv2.cpp` to match.
275-
if (auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(layout)) {
276-
if (mmaLayout.isHopper()) {
277-
// Hopper MMA instructions force warps to be column-major
278-
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-wgmma-mma-async-m64nnk8
279-
return getMatrixOrder(order.size(), /*rowMajor*/ false);
280-
}
281-
} else if (auto dotOpLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
282-
// It's quite weird to talk about warp order when that the warps
283-
// are broadcasted along the K dimension
284-
llvm::report_fatal_error(
285-
"DotOperandEncoding::getWarpOrder not implemented");
286-
}
287-
return order;
265+
if (auto distributedLayout = mlir::dyn_cast<DistributedEncodingTrait>(layout))
266+
return distributedLayout.getWarpOrder();
267+
else
268+
llvm::report_fatal_error("Unimplemented usage of getThreadOrder");
269+
return {};
288270
}
289271

290272
SmallVector<unsigned> getOrder(Attribute layout) {
@@ -293,7 +275,7 @@ SmallVector<unsigned> getOrder(Attribute layout) {
293275
}
294276
if (auto mmaLayout = dyn_cast<MmaEncodingTrait>(layout)) {
295277
// Order doesn't really matter. We just have to be consistent when unpacking
296-
// the elements in the MMAv2/V3 lowerings. We choose row-major
278+
// the output elements in the LLVM lowerings. We choose row-major
297279
auto distributedLayout = cast<DistributedEncodingTrait>(layout);
298280
auto rank = distributedLayout.getWarpsPerCTA().size();
299281
return getMatrixOrder(rank, /*rowMajor*/ true);
@@ -318,15 +300,15 @@ SmallVector<unsigned> getOrder(Attribute layout) {
318300

319301
llvm::report_fatal_error("Unimplemented usage of getOrder");
320302
return {};
321-
};
303+
}
322304

323305
SmallVector<unsigned> getThreadOrder(Attribute layout) {
324306
if (auto distributedLayout = mlir::dyn_cast<DistributedEncodingTrait>(layout))
325307
return distributedLayout.getThreadOrder();
326308
else
327309
llvm::report_fatal_error("Unimplemented usage of getThreadOrder");
328310
return {};
329-
};
311+
}
330312

331313
CTALayoutAttr getCTALayout(Attribute layout) {
332314
if (auto distributedLayout =
@@ -769,7 +751,8 @@ SmallVector<unsigned> SliceEncodingAttr::getWarpsPerCTA() const {
769751
return warpsPerCTA;
770752
}
771753
SmallVector<unsigned> SliceEncodingAttr::getWarpOrder() const {
772-
return ::getWarpOrder(*this);
754+
auto parentWarpOrder = ::getWarpOrder(getParent());
755+
return eraseOrder(parentWarpOrder, getDim());
773756
}
774757
SmallVector<unsigned> SliceEncodingAttr::getThreadsPerWarp() const {
775758
auto parent = getParent();
@@ -781,7 +764,8 @@ SmallVector<unsigned> SliceEncodingAttr::getThreadsPerWarp() const {
781764
return threadsPerWarp;
782765
}
783766
SmallVector<unsigned> SliceEncodingAttr::getThreadOrder() const {
784-
return ::getOrder(*this);
767+
auto parentThreadOrder = ::getThreadOrder(getParent());
768+
return eraseOrder(parentThreadOrder, getDim());
785769
}
786770
SmallVector<unsigned> SliceEncodingAttr::getSizePerThread() const {
787771
auto sizePerThread = ::getSizePerThread(getParent());
@@ -1049,7 +1033,14 @@ SmallVector<unsigned> DotOperandEncodingAttr::getWarpsPerCTA() const {
10491033
return warps;
10501034
}
10511035
SmallVector<unsigned> DotOperandEncodingAttr::getWarpOrder() const {
1052-
return ::getWarpOrder(*this);
1036+
// FIXME(Lezcano): Preexisting. Do we want to have this path at all?
1037+
if (mlir::isa<AMDMfmaEncodingAttr>(getParent())) {
1038+
return ::getWarpOrder(getParent());
1039+
}
1040+
// It's quite weird to talk about warp order when the warps
1041+
// are broadcasted along the K dimension
1042+
llvm::report_fatal_error("DotOperandEncoding::getWarpOrder not implemented");
1043+
return {};
10531044
}
10541045
SmallVector<unsigned> DotOperandEncodingAttr::getThreadOrder() const {
10551046
return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(),
@@ -1597,7 +1588,7 @@ SmallVector<unsigned> AMDMfmaEncodingAttr::getWarpsPerCTA() const {
15971588
return SmallVector<unsigned>(getWarpsPerCTA__());
15981589
}
15991590
SmallVector<unsigned> AMDMfmaEncodingAttr::getWarpOrder() const {
1600-
return ::getWarpOrder(*this);
1591+
return ::getOrder(*this);
16011592
}
16021593
SmallVector<unsigned> AMDMfmaEncodingAttr::getThreadOrder() const {
16031594
auto order = ::getOrder(*this);
@@ -1766,7 +1757,7 @@ SmallVector<unsigned> AMDWmmaEncodingAttr::getWarpsPerCTA() const {
17661757
return SmallVector<unsigned>(getWarpsPerCTA__());
17671758
}
17681759
SmallVector<unsigned> AMDWmmaEncodingAttr::getWarpOrder() const {
1769-
return ::getWarpOrder(*this);
1760+
return ::getOrder(*this);
17701761
}
17711762
SmallVector<unsigned> AMDWmmaEncodingAttr::getThreadOrder() const {
17721763
return ::getOrder(*this);
@@ -1890,7 +1881,11 @@ SmallVector<unsigned> NvidiaMmaEncodingAttr::getWarpsPerCTA() const {
18901881
return SmallVector<unsigned>(getWarpsPerCTA__());
18911882
}
18921883
SmallVector<unsigned> NvidiaMmaEncodingAttr::getWarpOrder() const {
1893-
return ::getWarpOrder(*this);
1884+
auto rank = getWarpsPerCTA().size();
1885+
// Hopper (wgmma) uses column-major as this is embedded in the instruction
1886+
// For Ampere we can choose either row-major or column-major.
1887+
// We choose row-major as the legacy path did so
1888+
return getMatrixOrder(rank, /*rowMajor*/ !isHopper());
18941889
}
18951890
SmallVector<unsigned> NvidiaMmaEncodingAttr::getThreadsPerWarp() const {
18961891
auto rank = getWarpsPerCTA().size();
@@ -1914,10 +1909,11 @@ SmallVector<unsigned> NvidiaMmaEncodingAttr::getThreadsPerWarp() const {
19141909
"getThreadsPerWarp not implemented for unknown Mma version ");
19151910
}
19161911
SmallVector<unsigned> NvidiaMmaEncodingAttr::getThreadOrder() const {
1917-
return ::getOrder(*this);
1912+
auto rank = getWarpsPerCTA().size();
1913+
return getMatrixOrder(rank, /*rowMajor*/ true);
19181914
}
19191915
SmallVector<unsigned> NvidiaMmaEncodingAttr::getSizePerThread() const {
1920-
auto rank = ::getOrder(*this).size();
1916+
auto rank = getWarpsPerCTA().size();
19211917
SmallVector<unsigned> res(rank, 1);
19221918
if (isAmpere()) {
19231919
res[rank - 2] = 2;
@@ -2158,11 +2154,10 @@ NvidiaMmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
21582154
if (opIdx == 0) {
21592155
sizePerThread[rank - 2] = 2;
21602156
sizePerThread[rank - 1] = 2 * kWidth;
2161-
} else if (opIdx == 1) {
2157+
} else {
2158+
assert(opIdx == 1);
21622159
sizePerThread[rank - 2] = 2 * kWidth;
21632160
sizePerThread[rank - 1] = 1;
2164-
} else {
2165-
llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
21662161
}
21672162
return sizePerThread;
21682163
}

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,7 @@ LinearLayout hopperMmaToLinearLayout(ArrayRef<int64_t> shape,
327327
assert(n == 8 || n == 16 || n == 32 || n == 64 || n == 128 || n == 256);
328328
assert(k == 8 || k == 16 || k == 32);
329329

330+
// TODO Make the getOrder of Hopper explicit here via an assert
330331
MLIRContext *ctx = mma.getContext();
331332
LinearLayout ctaLayout(
332333
{{S("register"), {{1, 0}, {0, 8}}},

lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,17 +44,17 @@ class TritonGPUReduceDataDuplicationPass
4444
return;
4545
if (!cvtNeedsSharedMemory(srcType, dstType))
4646
return;
47-
auto srcOrder = triton::gpu::getOrder(srcEncoding);
48-
auto rank = srcOrder.size();
47+
auto srcThreadOrder = triton::gpu::getThreadOrder(srcEncoding);
48+
auto rank = srcThreadOrder.size();
4949
SmallVector<unsigned> sharedOrder;
5050
if (rank == 3) {
5151
// add all elements except the element that is zero
5252
for (unsigned i = 0; i < rank; ++i)
53-
if (srcOrder[i] != 0)
54-
sharedOrder.emplace_back(srcOrder[i]);
53+
if (srcThreadOrder[i] != 0)
54+
sharedOrder.emplace_back(srcThreadOrder[i]);
5555
sharedOrder.emplace_back(0);
5656
} else {
57-
sharedOrder = srcOrder;
57+
sharedOrder = srcThreadOrder;
5858
}
5959
auto sharedMemorySpace =
6060
triton::gpu::SharedMemorySpaceAttr::get(srcType.getContext());

python/test/unit/language/test_core.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1765,7 +1765,6 @@ def kernel(output_ptr, n_elements, BLOCK_SIZE: tl.constexpr, CONSTANT_FIELD: tl.
17651765
kernel[(1, )](output, block_size, BLOCK_SIZE=block_size, num_ctas=num_ctas, CONSTANT_FIELD=constant_field)
17661766

17671767
if constant_field == "value":
1768-
print(output, ref)
17691768
assert torch.all(output == ref)
17701769
else:
17711770
assert torch.all(output == 0)

0 commit comments

Comments
 (0)