Skip to content

Commit 016ee00

Browse files
Revert "[LAYOUTS] Unify the implementation of getShapePerCTA (#5183)"
This reverts commit d36e579.
1 parent 6ccad72 commit 016ee00

File tree

12 files changed

+153
-29
lines changed

12 files changed

+153
-29
lines changed

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,9 @@ SmallVector<unsigned> getCTAOrder(Attribute layout);
  * (3) In the implementation of emitIndices, ShapePerCTATile will
  * be replicated or wrapped to fit ShapePerCTA.
  */
-SmallVector<unsigned> getShapePerCTATile(Attribute layout);
+SmallVector<unsigned>
+getShapePerCTATile(Attribute layout,
+                   ArrayRef<int64_t> tensorShape = ArrayRef<int64_t>());
 
 SmallVector<int64_t> getShapePerCTA(ArrayRef<unsigned> CTASplitNum,
                                     ArrayRef<int64_t> shape);

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,11 @@ We call each individual tile "rep".
       "SmallVector<unsigned>",
       "getCTASplitNum">,
 
+    InterfaceMethod<"Gets the shape of the encoding's tile, e.g. sizePerThread * threadsPerWarp * warpsPerCTA",
+      "SmallVector<unsigned>",
+      "getShapePerCTATile",
+      (ins "ArrayRef<int64_t>":$tensorShape)>,
+
     InterfaceMethod<"Gets the number of contiguous elements per thread.",
       "SmallVector<unsigned>",
       "getContigPerThread">,
@@ -560,6 +565,7 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
     SmallVector<unsigned> getThreadOrder() const;
 
     SmallVector<unsigned> getSizePerThread() const;
+    SmallVector<unsigned> getShapePerCTATile(ArrayRef<int64_t> tensorShape = ArrayRef<int64_t>()) const;
 
     std::optional<LinearLayout> toLinearLayout(ArrayRef<int64_t> shape) const;
   }];
@@ -759,6 +765,13 @@ def MmaEncodingTrait : AttrInterface<"MmaEncodingTrait"> {
       "bool",
       "supportReduction">,
 
+    InterfaceMethod<"Return shape per CTA.",
+      "SmallVector<unsigned>",
+      "getShapePerCTATileForOperand",
+      (ins "ArrayRef<int64_t>":$tensorShape,
+           "int":$kWidth,
+           "int":$opIdx)>,
+
     InterfaceMethod<"Return size per thread for dot operands.",
       "SmallVector<unsigned>",
       "getSizePerThreadForOperand",
@@ -892,6 +905,7 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,
       return true;
     }
     SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
+    SmallVector<unsigned> getShapePerCTATileForOperand(ArrayRef<int64_t> shape, int kWidth, int opIdx) const;
     unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
     SmallVector<int64_t> getInstrShapeForOperand(int kWidth, int opIdx) const;
     SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape, int kWidth, int opIdx) const;
@@ -1003,6 +1017,7 @@ Row | warp 0 warp 2
       return true;
     }
     SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
+    SmallVector<unsigned> getShapePerCTATileForOperand(ArrayRef<int64_t> shape, int kWidth, int opIdx) const;
     unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
     SmallVector<int64_t> getElemsPerInstrForOperands() const;
     SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape,
@@ -1138,6 +1153,7 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
       return false;
     };
     SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
+    SmallVector<unsigned> getShapePerCTATileForOperand(ArrayRef<int64_t> shape, int kWidth, int opIdx) const;
 
     SmallVector<unsigned> getContigPerThread() {
       assert(isAmpere() || isHopper());

lib/Analysis/Allocation.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,10 @@ static SmallVector<unsigned> getRepShapeForCvt(RankedTensorType srcTy,
 
   auto srcShapePerCTA = gpu::getShapePerCTA(srcTy);
   auto dstShapePerCTA = gpu::getShapePerCTA(dstTy);
-  auto srcShapePerCTATile = gpu::getShapePerCTATile(srcLayout);
-  auto dstShapePerCTATile = gpu::getShapePerCTATile(dstLayout);
+  auto srcShapePerCTATile =
+      gpu::getShapePerCTATile(srcLayout, srcTy.getShape());
+  auto dstShapePerCTATile =
+      gpu::getShapePerCTATile(dstLayout, dstTy.getShape());
 
   assert(srcTy.getRank() == dstTy.getRank() &&
          "src and dst must have the same rank");

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,8 +174,8 @@ struct ConvertLayoutOpConversion
     SmallVector<unsigned> outNumCTAsEachRep(rank);
     SmallVector<unsigned> inNumCTAs(rank);
     SmallVector<unsigned> outNumCTAs(rank);
-    auto srcShapePerCTATile = getShapePerCTATile(srcLayout);
-    auto dstShapePerCTATile = getShapePerCTATile(dstLayout);
+    auto srcShapePerCTATile = getShapePerCTATile(srcLayout, srcTy.getShape());
+    auto dstShapePerCTATile = getShapePerCTATile(dstLayout, shape);
     auto shapePerCTA = getShapePerCTA(srcLayout, shape);
 
     for (unsigned d = 0; d < rank; ++d) {

lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -421,7 +421,7 @@ struct ReduceOpConversion
     auto resultIndices = emitIndices(loc, rewriter, targetInfo,
                                      resultLayout, resultTy, true);
     auto resultShape = resultTy.getShape();
-    auto resultCTATile = getShapePerCTATile(resultLayout);
+    auto resultCTATile = getShapePerCTATile(resultLayout, resultShape);
     assert(resultIndices.size() == resultElems);
 
     SmallVector<Value> resultVals(resultElems);

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 119 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -204,25 +204,12 @@ SmallVector<unsigned> getUniqueContigPerThread(Attribute layout,
   }
   return ret;
 }
-SmallVector<unsigned> getShapePerCTATile(Attribute layout) {
+
+SmallVector<unsigned> getShapePerCTATile(Attribute layout,
+                                         ArrayRef<int64_t> tensorShape) {
   if (auto distributedLayout =
           mlir::dyn_cast<DistributedEncodingTrait>(layout)) {
-    auto sizePerThread = distributedLayout.getSizePerThread();
-    auto threadsPerWarp = distributedLayout.getThreadsPerWarp();
-    // ThreadsPerWarp does not align with this function for slice layout
-    if (auto sliceLayout = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
-      threadsPerWarp = getThreadsPerWarp(sliceLayout.getParent());
-      threadsPerWarp.erase(threadsPerWarp.begin() + sliceLayout.getDim());
-    }
-    auto warpsPerCTA = distributedLayout.getWarpsPerCTA();
-    assert(sizePerThread.size() == threadsPerWarp.size() &&
-           sizePerThread.size() == warpsPerCTA.size());
-    SmallVector<unsigned> shape;
-    for (auto [size, thread, warp] :
-         llvm::zip(sizePerThread, threadsPerWarp, warpsPerCTA)) {
-      shape.push_back(size * thread * warp);
-    }
-    return shape;
+    return distributedLayout.getShapePerCTATile(tensorShape);
   } else {
     llvm::report_fatal_error("getShapePerCTATile not implemented");
     return SmallVector<unsigned>();
@@ -704,6 +691,14 @@ SmallVector<unsigned> BlockedEncodingAttr::getThreadOrder() const {
 SmallVector<unsigned> BlockedEncodingAttr::getSizePerThread() const {
   return SmallVector<unsigned>(getSizePerThread__());
 }
+SmallVector<unsigned>
+BlockedEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
+  SmallVector<unsigned> shape;
+  for (unsigned d = 0, n = getOrder().size(); d < n; ++d)
+    shape.push_back(getSizePerThread()[d] * getThreadsPerWarp()[d] *
+                    getWarpsPerCTA()[d]);
+  return shape;
+}
 
 template <class T>
 SmallVector<T> SliceEncodingAttr::paddedShape(ArrayRef<T> shape) const {
@@ -805,6 +800,12 @@ SmallVector<unsigned> SliceEncodingAttr::getSizePerThread() const {
   sizePerThread.erase(sizePerThread.begin() + getDim());
   return sizePerThread;
 }
+SmallVector<unsigned>
+SliceEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
+  SmallVector<unsigned> shape = ::getShapePerCTATile(getParent(), tensorShape);
+  shape.erase(shape.begin() + getDim());
+  return shape;
+}
 
 //
 
@@ -998,9 +999,9 @@ unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
   }
   if (auto blockedLayout = mlir::dyn_cast<BlockedEncodingAttr>(getParent())) {
     auto shapePerCTA = getShapePerCTA(*this, shape);
-    auto shapePerCTATile = getShapePerCTATile(blockedLayout);
+    auto shapePerCTATile = ::getShapePerCTATile(blockedLayout);
     auto order = blockedLayout.getOrder();
-    auto sizePerThread = blockedLayout.getSizePerThread();
+    auto sizePerThread = ::getSizePerThread(blockedLayout);
 
     int K = getOpIdx() == 0 ? shapePerCTA[1] : shapePerCTA[0];
     int otherDim = getOpIdx() == 1 ? shapePerCTA[1] : shapePerCTA[0];
@@ -1071,6 +1072,19 @@ SmallVector<unsigned> DotOperandEncodingAttr::getThreadOrder() const {
                           /*kMajor*/ true);
   }
 }
+SmallVector<unsigned> DotOperandEncodingAttr::getShapePerCTATile(
+    ArrayRef<int64_t> tensorShape) const {
+  auto parentLayout = getParent();
+  assert(parentLayout && "DotOperandEncodingAttr must have a parent");
+  if (auto parentMmaLayout = mlir::dyn_cast<MmaEncodingTrait>(parentLayout)) {
+    return parentMmaLayout.getShapePerCTATileForOperand(
+        tensorShape, getKWidth(), getOpIdx());
+  } else {
+    llvm::report_fatal_error(
+        "DotOperandEncodingAttr non-NvidiaMmaEncodingAttr parent not "
+        "supported yet");
+  }
+}
 
 LogicalResult DotOperandEncodingAttr::verify(
     ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
@@ -1592,6 +1606,16 @@ void SharedEncodingAttr::print(AsmPrinter &printer) const {
 //===----------------------------------------------------------------------===//
 // TODO: there is a lot of common code with MmaEncoding here
 
+SmallVector<unsigned>
+AMDMfmaEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
+  auto warpsPerCTA = getWarpsPerCTA();
+  auto rank = warpsPerCTA.size();
+  SmallVector<unsigned> shapePerCTATile(warpsPerCTA.begin(), warpsPerCTA.end());
+  shapePerCTATile[rank - 1] *= getMDim();
+  shapePerCTATile[rank - 2] *= getNDim();
+  return shapePerCTATile;
+}
+
 SmallVector<unsigned> AMDMfmaEncodingAttr::getCTAsPerCGA() const {
   return SmallVector<unsigned>(getCTALayout().getCTAsPerCGA());
 }
@@ -1735,10 +1759,43 @@ AMDMfmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
   return sizePerThread;
 }
 
+SmallVector<unsigned>
+AMDMfmaEncodingAttr::getShapePerCTATileForOperand(ArrayRef<int64_t> shape,
+                                                  int kWidth, int opIdx) const {
+  assert(getMDim() == 32 || getMDim() == 16);
+  auto parentShapePerCTATile = getShapePerCTATile(shape);
+  auto rank = parentShapePerCTATile.size();
+  if (opIdx == 0) {
+    if (rank == 2)
+      return {parentShapePerCTATile[rank - 2], 32};
+    else
+      return {parentShapePerCTATile[0], parentShapePerCTATile[rank - 2], 32};
+  } else if (opIdx == 1) {
+    if (rank == 2)
+      return {32, parentShapePerCTATile[rank - 1]};
+    else
+      return {parentShapePerCTATile[0], 32, parentShapePerCTATile[rank - 1]};
+  } else {
+    llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
+  }
+  llvm_unreachable("DotOperandEncodingAttr opIdx must be 0 or 1");
+}
+
 //===----------------------------------------------------------------------===//
 // Wmma encoding
 //===----------------------------------------------------------------------===//
 
+SmallVector<unsigned>
+AMDWmmaEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
+  auto warpsPerCTA = getWarpsPerCTA();
+  auto rank = warpsPerCTA.size();
+  SmallVector<unsigned> shapePerCTATile(warpsPerCTA.begin(), warpsPerCTA.end());
+
+  auto mnkDim = getMNKDimPerInstr();
+  shapePerCTATile[rank - 2] *= mnkDim[0];
+  shapePerCTATile[rank - 1] *= mnkDim[1];
+  return shapePerCTATile;
+}
 SmallVector<unsigned> AMDWmmaEncodingAttr::getRepOrder() const {
   auto rank = getWarpsPerCTA().size();
   return getMatrixOrder(rank, /*rowMajor*/ true);
@@ -1803,6 +1860,21 @@ AMDWmmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
   return sizePerThread;
 }
 
+SmallVector<unsigned>
+AMDWmmaEncodingAttr::getShapePerCTATileForOperand(ArrayRef<int64_t> shape,
+                                                  int kWidth, int opIdx) const {
+  auto parentShapePerCTA = getShapePerCTATile(shape);
+  auto rank = shape.size();
+  assert(rank == 2);
+  if (opIdx == 0) {
+    return {parentShapePerCTA[0], static_cast<unsigned>(shape[1])};
+  } else if (opIdx == 1) {
+    return {static_cast<unsigned>(shape[0]), parentShapePerCTA[1]};
+  } else {
+    llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
+  }
+}
+
 unsigned AMDWmmaEncodingAttr::getTotalElemsPerThreadForOperand(
     ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const {
   auto rep = getRepForOperand(shape, eltTy, kWidth, opIdx);
@@ -1921,6 +1993,24 @@ SmallVector<unsigned> NvidiaMmaEncodingAttr::getSizePerThread() const {
   llvm_unreachable("Unexpected mma version");
 }
 
+SmallVector<unsigned>
+NvidiaMmaEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
+  if (isAmpere()) {
+    auto warpsPerCTA = getWarpsPerCTA();
+    auto rank = warpsPerCTA.size();
+    SmallVector<unsigned> shapePerCTATile(warpsPerCTA.begin(),
+                                          warpsPerCTA.end());
+    shapePerCTATile[rank - 1] *= 8;
+    shapePerCTATile[rank - 2] *= 16;
+    return shapePerCTATile;
+  }
+  if (isHopper()) {
+    auto instrShape = getInstrShape();
+    return {16 * getWarpsPerCTA()[0], instrShape[1] * getWarpsPerCTA()[1]};
+  }
+  llvm::report_fatal_error("Unexpected MMA layout version found");
+}
+
 SmallVector<unsigned>
 NvidiaMmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
   auto rank = getWarpsPerCTA().size();
@@ -1961,6 +2051,16 @@ NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> shape, int bitwidth,
   }
 }
 
+SmallVector<unsigned> NvidiaMmaEncodingAttr::getShapePerCTATileForOperand(
+    ArrayRef<int64_t> shape, int kWidth, int opIdx) const {
+  assert(isAmpere() && "mmaLayout Hopper is not implemented yet");
+  auto shapePerCTATile = getShapePerCTATile(shape);
+  auto rank = shapePerCTATile.size();
+  auto kDim = opIdx == 0 ? rank - 1 : rank - 2;
+  // 4 threads * 2 subtiles
+  shapePerCTATile[kDim] = kWidth * 2 * 4;
+  return shapePerCTATile;
+}
 SmallVector<unsigned>
 NvidiaMmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
   auto rank = getWarpsPerCTA().size();

third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,8 @@ LogicalResult ExtractSliceOp::verify() {
   }
 
   auto srcShape = srcTy.getShape();
-  auto shapePerCTATile = mlir::triton::gpu::getShapePerCTATile(srcLayout);
+  auto shapePerCTATile =
+      mlir::triton::gpu::getShapePerCTATile(srcLayout, srcShape);
   shapePerCTATile[0] =
       std::min(static_cast<unsigned>(srcShape[0]), shapePerCTATile[0]);
   shapePerCTATile[1] =

third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ struct ExtractSliceOpConversion
     auto order = triton::gpu::getOrder(srcLayout);
 
     // Calculate valid total number of workers in each dimension
-    auto shapePerCTATile = triton::gpu::getShapePerCTATile(srcLayout);
+    auto shapePerCTATile = triton::gpu::getShapePerCTATile(srcLayout, srcShape);
     shapePerCTATile[0] =
         std::min(static_cast<unsigned>(srcShape[0]), shapePerCTATile[0]);
     shapePerCTATile[1] =

third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter,
   } else {
     warpOrder = triton::gpu::getWarpOrder(layout);
   }
-  auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout);
+  auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout, shape);
   Value warpSize = i32_val(triton::gpu::getWarpSize(layout));
   Value laneId = urem(tid, warpSize);
   Value warpId = udiv(tid, warpSize);

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ using ::mlir::LLVM::linearize;
 using ::mlir::triton::gpu::DotOperandEncodingAttr;
 using ::mlir::triton::gpu::getOrder;
 using ::mlir::triton::gpu::getShapePerCTA;
+using ::mlir::triton::gpu::getShapePerCTATile;
 using ::mlir::triton::gpu::getSizePerThread;
 using ::mlir::triton::gpu::getTotalElemsPerThread;
 using ::mlir::triton::gpu::SharedEncodingAttr;

0 commit comments

Comments
 (0)