Commit 6ccad72

Merge commit 'd36e57998e7aaf7495c5d916b51948b2a42ae5f3'
2 parents: fd401e7 + d36e579

File tree: 12 files changed (+29 / -153 lines)


include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 1 addition & 3 deletions
@@ -116,9 +116,7 @@ SmallVector<unsigned> getCTAOrder(Attribute layout);
  * (3) In the implementation of emitIndices, ShapePerCTATile will
  *     be replicated or wrapped to fit ShapePerCTA.
  */
-SmallVector<unsigned>
-getShapePerCTATile(Attribute layout,
-                   ArrayRef<int64_t> tensorShape = ArrayRef<int64_t>());
+SmallVector<unsigned> getShapePerCTATile(Attribute layout);
 
 SmallVector<int64_t> getShapePerCTA(ArrayRef<unsigned> CTASplitNum,
                                     ArrayRef<int64_t> shape);
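
As the surviving header comment notes, ShapePerCTATile is "replicated or wrapped to fit ShapePerCTA" when indices are emitted. A minimal standalone C++ sketch of that relationship, using made-up shapes that are not taken from this commit (the per-dimension replication count is simply a ceiling division of the per-CTA shape by the tile shape):

// Hypothetical example (not Triton code): a CTA tile of shape [16, 32]
// covering a per-CTA tensor shape of [64, 16]. The tile is replicated
// 4x along dim 0 and only partially used (wrapped) along dim 1.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  std::vector<unsigned> shapePerCTATile = {16, 32};   // assumed layout tile
  std::vector<std::int64_t> shapePerCTA = {64, 16};   // assumed tensor shape per CTA
  for (size_t d = 0; d < shapePerCTATile.size(); ++d) {
    // Ceiling division: how many copies of the tile cover this dimension.
    std::int64_t reps =
        (shapePerCTA[d] + shapePerCTATile[d] - 1) / shapePerCTATile[d];
    std::printf("dim %zu: %lld repetition(s) of the tile\n", d,
                static_cast<long long>(reps));
  }
  return 0;
}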

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 0 additions & 16 deletions
@@ -502,11 +502,6 @@ We call each individual tile "rep".
                     "SmallVector<unsigned>",
                     "getCTASplitNum">,
 
-    InterfaceMethod<"Gets the shape of the encoding's tile, e.g. sizePerThread * threadsPerWarp * warpsPerCTA",
-                    "SmallVector<unsigned>",
-                    "getShapePerCTATile",
-                    (ins "ArrayRef<int64_t>":$tensorShape)>,
-
     InterfaceMethod<"Gets the number of contiguous elements per thread.",
                     "SmallVector<unsigned>",
                     "getContigPerThread">,
@@ -565,7 +560,6 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
    SmallVector<unsigned> getThreadOrder() const;
 
    SmallVector<unsigned> getSizePerThread() const;
-   SmallVector<unsigned> getShapePerCTATile(ArrayRef<int64_t> tensorShape = ArrayRef<int64_t>()) const;
 
    std::optional<LinearLayout> toLinearLayout(ArrayRef<int64_t> shape) const;
  }];
@@ -765,13 +759,6 @@ def MmaEncodingTrait : AttrInterface<"MmaEncodingTrait"> {
                     "bool",
                     "supportReduction">,
 
-    InterfaceMethod<"Return shape per CTA.",
-                    "SmallVector<unsigned>",
-                    "getShapePerCTATileForOperand",
-                    (ins "ArrayRef<int64_t>":$tensorShape,
-                         "int":$kWidth,
-                         "int":$opIdx)>,
-
    InterfaceMethod<"Return size per thread for dot operands.",
                    "SmallVector<unsigned>",
                    "getSizePerThreadForOperand",
@@ -905,7 +892,6 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,
      return true;
    }
    SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
-   SmallVector<unsigned> getShapePerCTATileForOperand(ArrayRef<int64_t> shape, int kWidth, int opIdx) const;
    unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
    SmallVector<int64_t> getInstrShapeForOperand(int kWidth, int opIdx) const;
    SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape, int kWidth, int opIdx) const;
@@ -1017,7 +1003,6 @@ Row | warp 0 warp 2
      return true;
    }
    SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
-   SmallVector<unsigned> getShapePerCTATileForOperand(ArrayRef<int64_t> shape, int kWidth, int opIdx) const;
    unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
    SmallVector<int64_t> getElemsPerInstrForOperands() const;
    SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape,
@@ -1153,7 +1138,6 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
      return false;
    };
    SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
-   SmallVector<unsigned> getShapePerCTATileForOperand(ArrayRef<int64_t> shape, int kWidth, int opIdx) const;
 
    SmallVector<unsigned> getContigPerThread() {
      assert(isAmpere() || isHopper());

lib/Analysis/Allocation.cpp

Lines changed: 2 additions & 4 deletions
@@ -41,10 +41,8 @@ static SmallVector<unsigned> getRepShapeForCvt(RankedTensorType srcTy,
 
   auto srcShapePerCTA = gpu::getShapePerCTA(srcTy);
   auto dstShapePerCTA = gpu::getShapePerCTA(dstTy);
-  auto srcShapePerCTATile =
-      gpu::getShapePerCTATile(srcLayout, srcTy.getShape());
-  auto dstShapePerCTATile =
-      gpu::getShapePerCTATile(dstLayout, dstTy.getShape());
+  auto srcShapePerCTATile = gpu::getShapePerCTATile(srcLayout);
+  auto dstShapePerCTATile = gpu::getShapePerCTATile(dstLayout);
 
   assert(srcTy.getRank() == dstTy.getRank() &&
          "src and dst must have the same rank");

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 2 additions & 2 deletions
@@ -174,8 +174,8 @@ struct ConvertLayoutOpConversion
     SmallVector<unsigned> outNumCTAsEachRep(rank);
     SmallVector<unsigned> inNumCTAs(rank);
     SmallVector<unsigned> outNumCTAs(rank);
-    auto srcShapePerCTATile = getShapePerCTATile(srcLayout, srcTy.getShape());
-    auto dstShapePerCTATile = getShapePerCTATile(dstLayout, shape);
+    auto srcShapePerCTATile = getShapePerCTATile(srcLayout);
+    auto dstShapePerCTATile = getShapePerCTATile(dstLayout);
     auto shapePerCTA = getShapePerCTA(srcLayout, shape);
 
     for (unsigned d = 0; d < rank; ++d) {

lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp

Lines changed: 1 addition & 1 deletion
@@ -421,7 +421,7 @@ struct ReduceOpConversion
     auto resultIndices = emitIndices(loc, rewriter, targetInfo,
                                      resultLayout, resultTy, true);
     auto resultShape = resultTy.getShape();
-    auto resultCTATile = getShapePerCTATile(resultLayout, resultShape);
+    auto resultCTATile = getShapePerCTATile(resultLayout);
     assert(resultIndices.size() == resultElems);
 
     SmallVector<Value> resultVals(resultElems);

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 19 additions & 119 deletions
@@ -204,12 +204,25 @@ SmallVector<unsigned> getUniqueContigPerThread(Attribute layout,
   }
   return ret;
 }
-
-SmallVector<unsigned> getShapePerCTATile(Attribute layout,
-                                         ArrayRef<int64_t> tensorShape) {
+SmallVector<unsigned> getShapePerCTATile(Attribute layout) {
   if (auto distributedLayout =
           mlir::dyn_cast<DistributedEncodingTrait>(layout)) {
-    return distributedLayout.getShapePerCTATile(tensorShape);
+    auto sizePerThread = distributedLayout.getSizePerThread();
+    auto threadsPerWarp = distributedLayout.getThreadsPerWarp();
+    // ThreadsPerWarp does not align with this function for slice layout
+    if (auto sliceLayout = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
+      threadsPerWarp = getThreadsPerWarp(sliceLayout.getParent());
+      threadsPerWarp.erase(threadsPerWarp.begin() + sliceLayout.getDim());
+    }
+    auto warpsPerCTA = distributedLayout.getWarpsPerCTA();
+    assert(sizePerThread.size() == threadsPerWarp.size() &&
+           sizePerThread.size() == warpsPerCTA.size());
+    SmallVector<unsigned> shape;
+    for (auto [size, thread, warp] :
+         llvm::zip(sizePerThread, threadsPerWarp, warpsPerCTA)) {
+      shape.push_back(size * thread * warp);
+    }
+    return shape;
   } else {
     llvm::report_fatal_error("getShapePerCTATile not implemented");
     return SmallVector<unsigned>();
@@ -691,14 +704,6 @@ SmallVector<unsigned> BlockedEncodingAttr::getThreadOrder() const {
 SmallVector<unsigned> BlockedEncodingAttr::getSizePerThread() const {
   return SmallVector<unsigned>(getSizePerThread__());
 }
-SmallVector<unsigned>
-BlockedEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
-  SmallVector<unsigned> shape;
-  for (unsigned d = 0, n = getOrder().size(); d < n; ++d)
-    shape.push_back(getSizePerThread()[d] * getThreadsPerWarp()[d] *
-                    getWarpsPerCTA()[d]);
-  return shape;
-}
 
 template <class T>
 SmallVector<T> SliceEncodingAttr::paddedShape(ArrayRef<T> shape) const {
@@ -800,12 +805,6 @@ SmallVector<unsigned> SliceEncodingAttr::getSizePerThread() const {
   sizePerThread.erase(sizePerThread.begin() + getDim());
   return sizePerThread;
 }
-SmallVector<unsigned>
-SliceEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
-  SmallVector<unsigned> shape = ::getShapePerCTATile(getParent(), tensorShape);
-  shape.erase(shape.begin() + getDim());
-  return shape;
-}
 
 //
 
@@ -999,9 +998,9 @@ unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
   }
   if (auto blockedLayout = mlir::dyn_cast<BlockedEncodingAttr>(getParent())) {
     auto shapePerCTA = getShapePerCTA(*this, shape);
-    auto shapePerCTATile = ::getShapePerCTATile(blockedLayout);
+    auto shapePerCTATile = getShapePerCTATile(blockedLayout);
     auto order = blockedLayout.getOrder();
-    auto sizePerThread = ::getSizePerThread(blockedLayout);
+    auto sizePerThread = blockedLayout.getSizePerThread();
 
     int K = getOpIdx() == 0 ? shapePerCTA[1] : shapePerCTA[0];
     int otherDim = getOpIdx() == 1 ? shapePerCTA[1] : shapePerCTA[0];
@@ -1072,19 +1071,6 @@ SmallVector<unsigned> DotOperandEncodingAttr::getThreadOrder() const {
                           /*kMajor*/ true);
   }
 }
-SmallVector<unsigned> DotOperandEncodingAttr::getShapePerCTATile(
-    ArrayRef<int64_t> tensorShape) const {
-  auto parentLayout = getParent();
-  assert(parentLayout && "DotOperandEncodingAttr must have a parent");
-  if (auto parentMmaLayout = mlir::dyn_cast<MmaEncodingTrait>(parentLayout)) {
-    return parentMmaLayout.getShapePerCTATileForOperand(
-        tensorShape, getKWidth(), getOpIdx());
-  } else {
-    llvm::report_fatal_error(
-        "DotOperandEncodingAttr non-NvidiaMmaEncodingAttr parent not "
-        "supported yet");
-  }
-}
 
 LogicalResult DotOperandEncodingAttr::verify(
     ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
@@ -1606,16 +1592,6 @@ void SharedEncodingAttr::print(AsmPrinter &printer) const {
 //===----------------------------------------------------------------------===//
 // TODO: there is a lot of common code with MmaEncoding here
 
-SmallVector<unsigned>
-AMDMfmaEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
-  auto warpsPerCTA = getWarpsPerCTA();
-  auto rank = warpsPerCTA.size();
-  SmallVector<unsigned> shapePerCTATile(warpsPerCTA.begin(), warpsPerCTA.end());
-  shapePerCTATile[rank - 1] *= getMDim();
-  shapePerCTATile[rank - 2] *= getNDim();
-  return shapePerCTATile;
-}
-
 SmallVector<unsigned> AMDMfmaEncodingAttr::getCTAsPerCGA() const {
   return SmallVector<unsigned>(getCTALayout().getCTAsPerCGA());
 }
@@ -1759,43 +1735,10 @@ AMDMfmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
   return sizePerThread;
 }
 
-SmallVector<unsigned>
-AMDMfmaEncodingAttr::getShapePerCTATileForOperand(ArrayRef<int64_t> shape,
-                                                  int kWidth, int opIdx) const {
-  assert(getMDim() == 32 || getMDim() == 16);
-  auto parentShapePerCTATile = getShapePerCTATile(shape);
-  auto rank = parentShapePerCTATile.size();
-  if (opIdx == 0) {
-    if (rank == 2)
-      return {parentShapePerCTATile[rank - 2], 32};
-    else
-      return {parentShapePerCTATile[0], parentShapePerCTATile[rank - 2], 32};
-  } else if (opIdx == 1) {
-    if (rank == 2)
-      return {32, parentShapePerCTATile[rank - 1]};
-    else
-      return {parentShapePerCTATile[0], 32, parentShapePerCTATile[rank - 1]};
-  } else {
-    llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
-  }
-  llvm_unreachable("DotOperandEncodingAttr opIdx must be 0 or 1");
-}
-
 //===----------------------------------------------------------------------===//
 // Wmma encoding
 //===----------------------------------------------------------------------===//
 
-SmallVector<unsigned>
-AMDWmmaEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
-  auto warpsPerCTA = getWarpsPerCTA();
-  auto rank = warpsPerCTA.size();
-  SmallVector<unsigned> shapePerCTATile(warpsPerCTA.begin(), warpsPerCTA.end());
-
-  auto mnkDim = getMNKDimPerInstr();
-  shapePerCTATile[rank - 2] *= mnkDim[0];
-  shapePerCTATile[rank - 1] *= mnkDim[1];
-  return shapePerCTATile;
-}
 SmallVector<unsigned> AMDWmmaEncodingAttr::getRepOrder() const {
   auto rank = getWarpsPerCTA().size();
   return getMatrixOrder(rank, /*rowMajor*/ true);
@@ -1860,21 +1803,6 @@ AMDWmmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
   return sizePerThread;
 }
 
-SmallVector<unsigned>
-AMDWmmaEncodingAttr::getShapePerCTATileForOperand(ArrayRef<int64_t> shape,
-                                                  int kWidth, int opIdx) const {
-  auto parentShapePerCTA = getShapePerCTATile(shape);
-  auto rank = shape.size();
-  assert(rank == 2);
-  if (opIdx == 0) {
-    return {parentShapePerCTA[0], static_cast<unsigned>(shape[1])};
-  } else if (opIdx == 1) {
-    return {static_cast<unsigned>(shape[0]), parentShapePerCTA[1]};
-  } else {
-    llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
-  }
-}
-
 unsigned AMDWmmaEncodingAttr::getTotalElemsPerThreadForOperand(
     ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const {
   auto rep = getRepForOperand(shape, eltTy, kWidth, opIdx);
@@ -1993,24 +1921,6 @@ SmallVector<unsigned> NvidiaMmaEncodingAttr::getSizePerThread() const {
   llvm_unreachable("Unexpected mma version");
 }
 
-SmallVector<unsigned>
-NvidiaMmaEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
-  if (isAmpere()) {
-    auto warpsPerCTA = getWarpsPerCTA();
-    auto rank = warpsPerCTA.size();
-    SmallVector<unsigned> shapePerCTATile(warpsPerCTA.begin(),
-                                          warpsPerCTA.end());
-    shapePerCTATile[rank - 1] *= 8;
-    shapePerCTATile[rank - 2] *= 16;
-    return shapePerCTATile;
-  }
-  if (isHopper()) {
-    auto instrShape = getInstrShape();
-    return {16 * getWarpsPerCTA()[0], instrShape[1] * getWarpsPerCTA()[1]};
-  }
-  llvm::report_fatal_error("Unexpected MMA layout version found");
-}
-
 SmallVector<unsigned>
 NvidiaMmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
   auto rank = getWarpsPerCTA().size();
@@ -2051,16 +1961,6 @@ NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> shape, int bitwidth,
   }
 }
 
-SmallVector<unsigned> NvidiaMmaEncodingAttr::getShapePerCTATileForOperand(
-    ArrayRef<int64_t> shape, int kWidth, int opIdx) const {
-  assert(isAmpere() && "mmaLayout Hopper is not implemented yet");
-  auto shapePerCTATile = getShapePerCTATile(shape);
-  auto rank = shapePerCTATile.size();
-  auto kDim = opIdx == 0 ? rank - 1 : rank - 2;
-  // 4 threads * 2 subtiles
-  shapePerCTATile[kDim] = kWidth * 2 * 4;
-  return shapePerCTATile;
-}
 SmallVector<unsigned>
 NvidiaMmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
   auto rank = getWarpsPerCTA().size();
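
To make the arithmetic in the new generic getShapePerCTATile above concrete, here is a small self-contained C++ sketch of the same per-dimension product for a hypothetical blocked encoding; the parameter values are illustrative assumptions, not taken from this commit. For a slice encoding, the change instead takes threadsPerWarp from the parent layout and erases the sliced dimension before forming the same product.

// Hypothetical blocked encoding: sizePerThread = [1, 4],
// threadsPerWarp = [8, 4], warpsPerCTA = [2, 2].
// shapePerCTATile[d] = sizePerThread[d] * threadsPerWarp[d] * warpsPerCTA[d],
// i.e. [1*8*2, 4*4*2] = [16, 32].
#include <cassert>
#include <cstdio>
#include <vector>

int main() {
  std::vector<unsigned> sizePerThread = {1, 4};
  std::vector<unsigned> threadsPerWarp = {8, 4};
  std::vector<unsigned> warpsPerCTA = {2, 2};
  assert(sizePerThread.size() == threadsPerWarp.size() &&
         sizePerThread.size() == warpsPerCTA.size());
  std::vector<unsigned> shapePerCTATile;
  for (size_t d = 0; d < sizePerThread.size(); ++d)
    shapePerCTATile.push_back(sizePerThread[d] * threadsPerWarp[d] *
                              warpsPerCTA[d]);
  // Prints "16 x 32": the CTA tile covers 16 rows and 32 columns.
  std::printf("%u x %u\n", shapePerCTATile[0], shapePerCTATile[1]);
  return 0;
}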

third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp

Lines changed: 1 addition & 2 deletions
@@ -78,8 +78,7 @@ LogicalResult ExtractSliceOp::verify() {
   }
 
   auto srcShape = srcTy.getShape();
-  auto shapePerCTATile =
-      mlir::triton::gpu::getShapePerCTATile(srcLayout, srcShape);
+  auto shapePerCTATile = mlir::triton::gpu::getShapePerCTATile(srcLayout);
   shapePerCTATile[0] =
       std::min(static_cast<unsigned>(srcShape[0]), shapePerCTATile[0]);
   shapePerCTATile[1] =

third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp

Lines changed: 1 addition & 1 deletion
@@ -70,7 +70,7 @@ struct ExtractSliceOpConversion
   auto order = triton::gpu::getOrder(srcLayout);
 
   // Calculate valid total number of workers in each dimension
-  auto shapePerCTATile = triton::gpu::getShapePerCTATile(srcLayout, srcShape);
+  auto shapePerCTATile = triton::gpu::getShapePerCTATile(srcLayout);
   shapePerCTATile[0] =
       std::min(static_cast<unsigned>(srcShape[0]), shapePerCTATile[0]);
   shapePerCTATile[1] =

third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 1 addition & 1 deletion
@@ -48,7 +48,7 @@ Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter,
   } else {
     warpOrder = triton::gpu::getWarpOrder(layout);
   }
-  auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout, shape);
+  auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout);
   Value warpSize = i32_val(triton::gpu::getWarpSize(layout));
   Value laneId = urem(tid, warpSize);
   Value warpId = udiv(tid, warpSize);

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 0 additions & 1 deletion
@@ -17,7 +17,6 @@ using ::mlir::LLVM::linearize;
 using ::mlir::triton::gpu::DotOperandEncodingAttr;
 using ::mlir::triton::gpu::getOrder;
 using ::mlir::triton::gpu::getShapePerCTA;
-using ::mlir::triton::gpu::getShapePerCTATile;
 using ::mlir::triton::gpu::getSizePerThread;
 using ::mlir::triton::gpu::getTotalElemsPerThread;
 using ::mlir::triton::gpu::SharedEncodingAttr;
