Skip to content

Commit d36e579

Browse files
authored
[LAYOUTS] Unify the implementation of getShapePerCTA (#5183)
We unify it and simplify its API (it was taking an unused `shape` parameter). While doing this, we found that the previous implementation was incorrect at least for `AMDWmmaEncodingAttr`, as this layout was using the shape parameter. Interestingly enough the doc in the header file for this function noted that the function is indeed independent of the tensor shape, even though the function does take a shape as an input! https://github.com/triton-lang/triton/blob/0bd30a2f3192204c5a50d5ffde27ad8493f6c026/include/triton/Dialect/TritonGPU/IR/Dialect.h#L113-L114
1 parent 7088c64 commit d36e579

File tree

12 files changed

+29
-153
lines changed

12 files changed

+29
-153
lines changed

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,7 @@ SmallVector<unsigned> getCTAOrder(Attribute layout);
116116
* (3) In the implementation of emitIndices, ShapePerCTATile will
117117
* be replicated or wrapped to fit ShapePerCTA.
118118
*/
119-
SmallVector<unsigned>
120-
getShapePerCTATile(Attribute layout,
121-
ArrayRef<int64_t> tensorShape = ArrayRef<int64_t>());
119+
SmallVector<unsigned> getShapePerCTATile(Attribute layout);
122120

123121
SmallVector<int64_t> getShapePerCTA(ArrayRef<unsigned> CTASplitNum,
124122
ArrayRef<int64_t> shape);

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -502,11 +502,6 @@ We call each individual tile "rep".
502502
"SmallVector<unsigned>",
503503
"getCTASplitNum">,
504504

505-
InterfaceMethod<"Gets the shape of the encoding's tile, e.g. sizePerThread * threadsPerWarp * warpsPerCTA",
506-
"SmallVector<unsigned>",
507-
"getShapePerCTATile",
508-
(ins "ArrayRef<int64_t>":$tensorShape)>,
509-
510505
InterfaceMethod<"Gets the number of contiguous elements per thread.",
511506
"SmallVector<unsigned>",
512507
"getContigPerThread">,
@@ -565,7 +560,6 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
565560
SmallVector<unsigned> getThreadOrder() const;
566561

567562
SmallVector<unsigned> getSizePerThread() const;
568-
SmallVector<unsigned> getShapePerCTATile(ArrayRef<int64_t> tensorShape = ArrayRef<int64_t>()) const;
569563

570564
std::optional<LinearLayout> toLinearLayout(ArrayRef<int64_t> shape) const;
571565
}];
@@ -765,13 +759,6 @@ def MmaEncodingTrait : AttrInterface<"MmaEncodingTrait"> {
765759
"bool",
766760
"supportReduction">,
767761

768-
InterfaceMethod<"Return shape per CTA.",
769-
"SmallVector<unsigned>",
770-
"getShapePerCTATileForOperand",
771-
(ins "ArrayRef<int64_t>":$tensorShape,
772-
"int":$kWidth,
773-
"int":$opIdx)>,
774-
775762
InterfaceMethod<"Return size per thread for dot operands.",
776763
"SmallVector<unsigned>",
777764
"getSizePerThreadForOperand",
@@ -900,7 +887,6 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,
900887
return true;
901888
}
902889
SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
903-
SmallVector<unsigned> getShapePerCTATileForOperand(ArrayRef<int64_t> shape, int kWidth, int opIdx) const;
904890
unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
905891
SmallVector<int64_t> getInstrShapeForOperand(int kWidth, int opIdx) const;
906892
SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape, int kWidth, int opIdx) const;
@@ -1008,7 +994,6 @@ Row | warp 0 warp 2
1008994
return true;
1009995
}
1010996
SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
1011-
SmallVector<unsigned> getShapePerCTATileForOperand(ArrayRef<int64_t> shape, int kWidth, int opIdx) const;
1012997
unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
1013998
SmallVector<int64_t> getElemsPerInstrForOperands() const;
1014999
SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape,
@@ -1140,7 +1125,6 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
11401125
return false;
11411126
};
11421127
SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
1143-
SmallVector<unsigned> getShapePerCTATileForOperand(ArrayRef<int64_t> shape, int kWidth, int opIdx) const;
11441128

11451129
SmallVector<unsigned> getContigPerThread() {
11461130
assert(isAmpere() || isHopper());

lib/Analysis/Allocation.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,8 @@ static SmallVector<unsigned> getRepShapeForCvt(RankedTensorType srcTy,
4141

4242
auto srcShapePerCTA = gpu::getShapePerCTA(srcTy);
4343
auto dstShapePerCTA = gpu::getShapePerCTA(dstTy);
44-
auto srcShapePerCTATile =
45-
gpu::getShapePerCTATile(srcLayout, srcTy.getShape());
46-
auto dstShapePerCTATile =
47-
gpu::getShapePerCTATile(dstLayout, dstTy.getShape());
44+
auto srcShapePerCTATile = gpu::getShapePerCTATile(srcLayout);
45+
auto dstShapePerCTATile = gpu::getShapePerCTATile(dstLayout);
4846

4947
assert(srcTy.getRank() == dstTy.getRank() &&
5048
"src and dst must have the same rank");

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,8 +174,8 @@ struct ConvertLayoutOpConversion
174174
SmallVector<unsigned> outNumCTAsEachRep(rank);
175175
SmallVector<unsigned> inNumCTAs(rank);
176176
SmallVector<unsigned> outNumCTAs(rank);
177-
auto srcShapePerCTATile = getShapePerCTATile(srcLayout, srcTy.getShape());
178-
auto dstShapePerCTATile = getShapePerCTATile(dstLayout, shape);
177+
auto srcShapePerCTATile = getShapePerCTATile(srcLayout);
178+
auto dstShapePerCTATile = getShapePerCTATile(dstLayout);
179179
auto shapePerCTA = getShapePerCTA(srcLayout, shape);
180180

181181
for (unsigned d = 0; d < rank; ++d) {

lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -421,7 +421,7 @@ struct ReduceOpConversion
421421
auto resultIndices = emitIndices(loc, rewriter, targetInfo,
422422
resultLayout, resultTy, true);
423423
auto resultShape = resultTy.getShape();
424-
auto resultCTATile = getShapePerCTATile(resultLayout, resultShape);
424+
auto resultCTATile = getShapePerCTATile(resultLayout);
425425
assert(resultIndices.size() == resultElems);
426426

427427
SmallVector<Value> resultVals(resultElems);

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 19 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -201,12 +201,25 @@ SmallVector<unsigned> getUniqueContigPerThread(Attribute layout,
201201
}
202202
return ret;
203203
}
204-
205-
SmallVector<unsigned> getShapePerCTATile(Attribute layout,
206-
ArrayRef<int64_t> tensorShape) {
204+
SmallVector<unsigned> getShapePerCTATile(Attribute layout) {
207205
if (auto distributedLayout =
208206
mlir::dyn_cast<DistributedEncodingTrait>(layout)) {
209-
return distributedLayout.getShapePerCTATile(tensorShape);
207+
auto sizePerThread = distributedLayout.getSizePerThread();
208+
auto threadsPerWarp = distributedLayout.getThreadsPerWarp();
209+
// ThreadsPerWarp does not align with this function for slice layout
210+
if (auto sliceLayout = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
211+
threadsPerWarp = getThreadsPerWarp(sliceLayout.getParent());
212+
threadsPerWarp.erase(threadsPerWarp.begin() + sliceLayout.getDim());
213+
}
214+
auto warpsPerCTA = distributedLayout.getWarpsPerCTA();
215+
assert(sizePerThread.size() == threadsPerWarp.size() &&
216+
sizePerThread.size() == warpsPerCTA.size());
217+
SmallVector<unsigned> shape;
218+
for (auto [size, thread, warp] :
219+
llvm::zip(sizePerThread, threadsPerWarp, warpsPerCTA)) {
220+
shape.push_back(size * thread * warp);
221+
}
222+
return shape;
210223
} else {
211224
llvm::report_fatal_error("getShapePerCTATile not implemented");
212225
return SmallVector<unsigned>();
@@ -678,14 +691,6 @@ SmallVector<unsigned> BlockedEncodingAttr::getThreadOrder() const {
678691
SmallVector<unsigned> BlockedEncodingAttr::getSizePerThread() const {
679692
return SmallVector<unsigned>(getSizePerThread__());
680693
}
681-
SmallVector<unsigned>
682-
BlockedEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
683-
SmallVector<unsigned> shape;
684-
for (unsigned d = 0, n = getOrder().size(); d < n; ++d)
685-
shape.push_back(getSizePerThread()[d] * getThreadsPerWarp()[d] *
686-
getWarpsPerCTA()[d]);
687-
return shape;
688-
}
689694

690695
template <class T>
691696
SmallVector<T> SliceEncodingAttr::paddedShape(ArrayRef<T> shape) const {
@@ -787,12 +792,6 @@ SmallVector<unsigned> SliceEncodingAttr::getSizePerThread() const {
787792
sizePerThread.erase(sizePerThread.begin() + getDim());
788793
return sizePerThread;
789794
}
790-
SmallVector<unsigned>
791-
SliceEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
792-
SmallVector<unsigned> shape = ::getShapePerCTATile(getParent(), tensorShape);
793-
shape.erase(shape.begin() + getDim());
794-
return shape;
795-
}
796795

797796
//
798797

@@ -979,9 +978,9 @@ unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
979978
}
980979
if (auto blockedLayout = mlir::dyn_cast<BlockedEncodingAttr>(getParent())) {
981980
auto shapePerCTA = getShapePerCTA(*this, shape);
982-
auto shapePerCTATile = ::getShapePerCTATile(blockedLayout);
981+
auto shapePerCTATile = getShapePerCTATile(blockedLayout);
983982
auto order = blockedLayout.getOrder();
984-
auto sizePerThread = ::getSizePerThread(blockedLayout);
983+
auto sizePerThread = blockedLayout.getSizePerThread();
985984

986985
int K = getOpIdx() == 0 ? shapePerCTA[1] : shapePerCTA[0];
987986
int otherDim = getOpIdx() == 1 ? shapePerCTA[1] : shapePerCTA[0];
@@ -1043,19 +1042,6 @@ SmallVector<unsigned> DotOperandEncodingAttr::getThreadOrder() const {
10431042
return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(),
10441043
/*kMajor*/ true);
10451044
}
1046-
SmallVector<unsigned> DotOperandEncodingAttr::getShapePerCTATile(
1047-
ArrayRef<int64_t> tensorShape) const {
1048-
auto parentLayout = getParent();
1049-
assert(parentLayout && "DotOperandEncodingAttr must have a parent");
1050-
if (auto parentMmaLayout = mlir::dyn_cast<MmaEncodingTrait>(parentLayout)) {
1051-
return parentMmaLayout.getShapePerCTATileForOperand(
1052-
tensorShape, getKWidth(), getOpIdx());
1053-
} else {
1054-
llvm::report_fatal_error(
1055-
"DotOperandEncodingAttr non-NvidiaMmaEncodingAttr parent not "
1056-
"supported yet");
1057-
}
1058-
}
10591045

10601046
LogicalResult DotOperandEncodingAttr::verify(
10611047
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
@@ -1562,16 +1548,6 @@ void SharedEncodingAttr::print(AsmPrinter &printer) const {
15621548
//===----------------------------------------------------------------------===//
15631549
// TODO: there is a lot of common code with MmaEncoding here
15641550

1565-
SmallVector<unsigned>
1566-
AMDMfmaEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
1567-
auto warpsPerCTA = getWarpsPerCTA();
1568-
auto rank = warpsPerCTA.size();
1569-
SmallVector<unsigned> shapePerCTATile(warpsPerCTA.begin(), warpsPerCTA.end());
1570-
shapePerCTATile[rank - 1] *= getMDim();
1571-
shapePerCTATile[rank - 2] *= getNDim();
1572-
return shapePerCTATile;
1573-
}
1574-
15751551
SmallVector<unsigned> AMDMfmaEncodingAttr::getCTAsPerCGA() const {
15761552
return SmallVector<unsigned>(getCTALayout().getCTAsPerCGA());
15771553
}
@@ -1715,43 +1691,10 @@ AMDMfmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
17151691
return sizePerThread;
17161692
}
17171693

1718-
SmallVector<unsigned>
1719-
AMDMfmaEncodingAttr::getShapePerCTATileForOperand(ArrayRef<int64_t> shape,
1720-
int kWidth, int opIdx) const {
1721-
assert(getMDim() == 32 || getMDim() == 16);
1722-
auto parentShapePerCTATile = getShapePerCTATile(shape);
1723-
auto rank = parentShapePerCTATile.size();
1724-
if (opIdx == 0) {
1725-
if (rank == 2)
1726-
return {parentShapePerCTATile[rank - 2], 32};
1727-
else
1728-
return {parentShapePerCTATile[0], parentShapePerCTATile[rank - 2], 32};
1729-
} else if (opIdx == 1) {
1730-
if (rank == 2)
1731-
return {32, parentShapePerCTATile[rank - 1]};
1732-
else
1733-
return {parentShapePerCTATile[0], 32, parentShapePerCTATile[rank - 1]};
1734-
} else {
1735-
llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
1736-
}
1737-
llvm_unreachable("DotOperandEncodingAttr opIdx must be 0 or 1");
1738-
}
1739-
17401694
//===----------------------------------------------------------------------===//
17411695
// Wmma encoding
17421696
//===----------------------------------------------------------------------===//
17431697

1744-
SmallVector<unsigned>
1745-
AMDWmmaEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
1746-
auto warpsPerCTA = getWarpsPerCTA();
1747-
auto rank = warpsPerCTA.size();
1748-
SmallVector<unsigned> shapePerCTATile(warpsPerCTA.begin(), warpsPerCTA.end());
1749-
1750-
auto mnkDim = getMNKDimPerInstr();
1751-
shapePerCTATile[rank - 2] *= mnkDim[0];
1752-
shapePerCTATile[rank - 1] *= mnkDim[1];
1753-
return shapePerCTATile;
1754-
}
17551698
SmallVector<unsigned> AMDWmmaEncodingAttr::getRepOrder() const {
17561699
auto rank = getWarpsPerCTA().size();
17571700
return getMatrixOrder(rank, /*rowMajor*/ true);
@@ -1816,21 +1759,6 @@ AMDWmmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
18161759
return sizePerThread;
18171760
}
18181761

1819-
SmallVector<unsigned>
1820-
AMDWmmaEncodingAttr::getShapePerCTATileForOperand(ArrayRef<int64_t> shape,
1821-
int kWidth, int opIdx) const {
1822-
auto parentShapePerCTA = getShapePerCTATile(shape);
1823-
auto rank = shape.size();
1824-
assert(rank == 2);
1825-
if (opIdx == 0) {
1826-
return {parentShapePerCTA[0], static_cast<unsigned>(shape[1])};
1827-
} else if (opIdx == 1) {
1828-
return {static_cast<unsigned>(shape[0]), parentShapePerCTA[1]};
1829-
} else {
1830-
llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
1831-
}
1832-
}
1833-
18341762
unsigned AMDWmmaEncodingAttr::getTotalElemsPerThreadForOperand(
18351763
ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const {
18361764
auto rep = getRepForOperand(shape, eltTy, kWidth, opIdx);
@@ -1949,24 +1877,6 @@ SmallVector<unsigned> NvidiaMmaEncodingAttr::getSizePerThread() const {
19491877
llvm_unreachable("Unexpected mma version");
19501878
}
19511879

1952-
SmallVector<unsigned>
1953-
NvidiaMmaEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
1954-
if (isAmpere()) {
1955-
auto warpsPerCTA = getWarpsPerCTA();
1956-
auto rank = warpsPerCTA.size();
1957-
SmallVector<unsigned> shapePerCTATile(warpsPerCTA.begin(),
1958-
warpsPerCTA.end());
1959-
shapePerCTATile[rank - 1] *= 8;
1960-
shapePerCTATile[rank - 2] *= 16;
1961-
return shapePerCTATile;
1962-
}
1963-
if (isHopper()) {
1964-
auto instrShape = getInstrShape();
1965-
return {16 * getWarpsPerCTA()[0], instrShape[1] * getWarpsPerCTA()[1]};
1966-
}
1967-
llvm::report_fatal_error("Unexpected MMA layout version found");
1968-
}
1969-
19701880
SmallVector<unsigned>
19711881
NvidiaMmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
19721882
auto rank = getWarpsPerCTA().size();
@@ -2007,16 +1917,6 @@ NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> shape, int bitwidth,
20071917
}
20081918
}
20091919

2010-
SmallVector<unsigned> NvidiaMmaEncodingAttr::getShapePerCTATileForOperand(
2011-
ArrayRef<int64_t> shape, int kWidth, int opIdx) const {
2012-
assert(isAmpere() && "mmaLayout Hopper is not implemented yet");
2013-
auto shapePerCTATile = getShapePerCTATile(shape);
2014-
auto rank = shapePerCTATile.size();
2015-
auto kDim = opIdx == 0 ? rank - 1 : rank - 2;
2016-
// 4 threads * 2 subtiles
2017-
shapePerCTATile[kDim] = kWidth * 2 * 4;
2018-
return shapePerCTATile;
2019-
}
20201920
SmallVector<unsigned>
20211921
NvidiaMmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
20221922
auto rank = getWarpsPerCTA().size();

third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,7 @@ LogicalResult ExtractSliceOp::verify() {
7878
}
7979

8080
auto srcShape = srcTy.getShape();
81-
auto shapePerCTATile =
82-
mlir::triton::gpu::getShapePerCTATile(srcLayout, srcShape);
81+
auto shapePerCTATile = mlir::triton::gpu::getShapePerCTATile(srcLayout);
8382
shapePerCTATile[0] =
8483
std::min(static_cast<unsigned>(srcShape[0]), shapePerCTATile[0]);
8584
shapePerCTATile[1] =

third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ struct ExtractSliceOpConversion
7070
auto order = triton::gpu::getOrder(srcLayout);
7171

7272
// Calculate valid total number of workers in each dimension
73-
auto shapePerCTATile = triton::gpu::getShapePerCTATile(srcLayout, srcShape);
73+
auto shapePerCTATile = triton::gpu::getShapePerCTATile(srcLayout);
7474
shapePerCTATile[0] =
7575
std::min(static_cast<unsigned>(srcShape[0]), shapePerCTATile[0]);
7676
shapePerCTATile[1] =

third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter,
4848
} else {
4949
warpOrder = triton::gpu::getWarpOrder(layout);
5050
}
51-
auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout, shape);
51+
auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout);
5252
Value warpSize = i32_val(triton::gpu::getWarpSize(layout));
5353
Value laneId = urem(tid, warpSize);
5454
Value warpId = udiv(tid, warpSize);

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ using ::mlir::LLVM::linearize;
1717
using ::mlir::triton::gpu::DotOperandEncodingAttr;
1818
using ::mlir::triton::gpu::getOrder;
1919
using ::mlir::triton::gpu::getShapePerCTA;
20-
using ::mlir::triton::gpu::getShapePerCTATile;
2120
using ::mlir::triton::gpu::getSizePerThread;
2221
using ::mlir::triton::gpu::getTotalElemsPerThread;
2322
using ::mlir::triton::gpu::SharedEncodingAttr;

0 commit comments

Comments (0)