Commit a8d53ad

lezcano authored and loislo committed
[LAYOUTS] Move all get.*ContigPerThread functions to a common API (triton-lang#6002)
There were a couple of things left to clean up after triton-lang#5840. We now provide a common API in terms of RankedTensorType: the per-layout getContigPerThread implementations and the shape-aware getUniqueContigPerThread helper are folded into a single getContigPerThread(RankedTensorType).
1 parent 3bea6fc commit a8d53ad
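
For context, here is a minimal sketch of what the consolidation means for call sites. It is illustrative only: the wrapper name contigOf below is hypothetical, while getContigPerThread and getUniqueContigPerThread are the functions this diff touches.

// Sketch, assuming a value whose type is a RankedTensorType carrying a
// TritonGPU layout encoding.
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

using namespace mlir;

SmallVector<unsigned> contigOf(RankedTensorType tensorTy) {
  // Before this commit, callers chose between two attribute-based helpers:
  //   triton::gpu::getContigPerThread(tensorTy.getEncoding());      // shape-agnostic
  //   triton::gpu::getUniqueContigPerThread(tensorTy.getEncoding(),
  //                                         tensorTy.getShape());   // shape-aware
  // After it, there is a single shape-aware entry point taking the tensor type:
  return triton::gpu::getContigPerThread(tensorTy);
}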

File tree

7 files changed: +19, -93 lines changed


include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h

Lines changed: 1 addition & 1 deletion
@@ -101,7 +101,7 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
     if (!axisInfo)
       // axis info (e.g., constancy) not available
       return resultVals;
-    SmallVector<unsigned> contigPerThread = getContigPerThread(encoding);
+    SmallVector<unsigned> contigPerThread = getContigPerThread(rtType);
     if (rank != contigPerThread.size())
       return resultVals;

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 3 additions & 10 deletions
@@ -104,20 +104,13 @@ SmallVector<unsigned> getWarpsPerCTA(Attribute layout);
 
 SmallVector<unsigned> getSizePerThread(Attribute layout);
 
-// Returns the number of contiguous elements that each thread
-// has access to, on each dimension of the tensor. E.g.
-// for a blocked layout with sizePerThread = [1, 4], returns [1, 4],
-// regardless of the shape of the tensor.
-SmallVector<unsigned> getContigPerThread(Attribute layout);
-
-// Returns the number of non-replicated contiguous elements that each thread
-// has access to, on each dimension of the tensor. For a blocked layout
+// Returns the number of contiguous elements of the logical tensor that each
+// thread has access to, on each dimension of the tensor. For a blocked layout
 // with sizePerThread = [1, 4] and tensor shape = [128, 1], the elements
 // for thread 0 would be [A_{0, 0}, A_{0, 0}, A_{0, 0}, A_{0, 0}], returns [1,
 // 1]. Whereas for a tensor shape [128, 128], the elements for thread 0 would be
 // [A_{0, 0}, A_{0, 1}, A_{0, 2}, A_{0, 3}], returns [1, 4].
-SmallVector<unsigned> getUniqueContigPerThread(Attribute layout,
-                                               ArrayRef<int64_t> tensorShape);
+SmallVector<unsigned> getContigPerThread(RankedTensorType tensorType);
 
 // Returns the number of threads per warp that have access to non-replicated
 // elements of the tensor. E.g. for a blocked layout with sizePerThread = [1,
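
The comment above pins down the contract with a [128, 1] vs [128, 128] example. A small sketch of how one might exercise it follows; it is not part of the diff, the blocked-encoding and CTA-layout parameters are assumptions chosen to match the comment's sizePerThread = [1, 4] example, and the MLIRContext is assumed to have the TritonGPU dialect loaded.

// Illustrative only; getContigPerThread(RankedTensorType) is the API from this
// patch, everything else here is assumed scaffolding.
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

using namespace mlir;
using namespace mlir::triton::gpu;

SmallVector<unsigned> contigFor(MLIRContext &ctx, ArrayRef<int64_t> shape) {
  // Assumed layout parameters: 32 threads per warp, 4 warps, sizePerThread = [1, 4].
  auto cta = CTALayoutAttr::get(&ctx, /*CTAsPerCGA=*/{1, 1},
                                /*CTASplitNum=*/{1, 1}, /*CTAOrder=*/{1, 0});
  auto blocked = BlockedEncodingAttr::get(
      &ctx, /*sizePerThread=*/{1, 4}, /*threadsPerWarp=*/{8, 4},
      /*warpsPerCTA=*/{4, 1}, /*order=*/{1, 0}, cta);
  auto tensorTy = RankedTensorType::get(shape, Float16Type::get(&ctx), blocked);
  // Per the comment: shape [128, 1] should give [1, 1] (the four elements per
  // thread all replicate A_{0, 0}), while [128, 128] should give [1, 4].
  return getContigPerThread(tensorTy);
}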

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 2 additions & 61 deletions
@@ -532,10 +532,6 @@ We call each individual tile "rep".
     InterfaceMethod<"Get the shape of the values per thread.",
                     "SmallVector<unsigned>",
                     "getSizePerThread">,
-
-    InterfaceMethod<"Gets the number of contiguous elements per thread.",
-                    "SmallVector<unsigned>",
-                    "getContigPerThread">,
     InterfaceMethod<"Convert to LinearLayout.",
                     "LinearLayout",
                     "toLinearLayout",
@@ -819,12 +815,7 @@ for
   }]>
   ];
 
-  let extraClassDeclaration = extraDistributedDeclaration # [{
-    SmallVector<unsigned> getContigPerThread() {
-      // Block encoding is dense stride layout. The elements per thread are contiguous.
-      return getSizePerThread();
-    };
-  }];
+  let extraClassDeclaration = extraDistributedDeclaration;
 
   let hasCustomAssemblyFormat = 1;
 }
@@ -972,17 +963,6 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,
     SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape, int kWidth, int opIdx) const;
     SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;
     SmallVector<unsigned> getThreadsPerWarpForOperand(int opIdx) const;
-
-    SmallVector<unsigned> getContigPerThread() {
-      auto rank = getWarpsPerCTA().size();
-      SmallVector<unsigned> contigPerThread(rank, 1);
-      if (getIsTransposed())
-        contigPerThread[rank - 1] = 4;
-      else
-        contigPerThread[rank - 2] = 4;
-      return contigPerThread;
-    };
-
   }];
 
   let genVerifyDecl = 1;
@@ -1100,16 +1080,6 @@ Row |
     SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;
     SmallVector<unsigned> getThreadsPerWarpForOperand(int opIdx) const;
     static SmallVector<unsigned> getMNKDimPerInstr();
-
-    SmallVector<unsigned> getContigPerThread() {
-      auto rank = getWarpsPerCTA().size();
-      assert(rank == 2 || rank == 3);
-      SmallVector<unsigned> contigPerThread(rank, 1);
-      if (getVersion() == 2) {
-        contigPerThread[rank - 2] = 8;
-      }
-      return contigPerThread;
-    };
   }];
 }
 
@@ -1219,15 +1189,6 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
     SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;
     SmallVector<unsigned> getThreadsPerWarpForOperand(int opIdx) const;
     SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
-
-    SmallVector<unsigned> getContigPerThread() {
-      assert(isAmpere() || isHopper());
-      auto rank = getWarpsPerCTA().size();
-      SmallVector<unsigned> contigPerThread(rank, 1);
-      contigPerThread[rank - 1] = 2;
-      return contigPerThread;
-    };
-
   }];
 
   let hasCustomAssemblyFormat = 1;
@@ -1273,13 +1234,6 @@ def SliceEncodingAttr : DistributedEncoding<"SliceEncoding", "slice_encoding"> {
   let extraClassDeclaration = extraDistributedDeclaration # [{
     template<class T>
     SmallVector<T> paddedShape(ArrayRef<T> shape) const;
-
-    SmallVector<unsigned> getContigPerThread() {
-      auto parentLayout = mlir::cast<DistributedEncodingTrait>(getParent());
-      auto parentContigPerThread = parentLayout.getContigPerThread();
-      parentContigPerThread.erase(parentContigPerThread.begin() + getDim());
-      return parentContigPerThread;
-    };
   }];
 
   let hasCustomAssemblyFormat = 1;
@@ -1347,20 +1301,7 @@ vecIdx (index of the element in the quad; this is always along the k-dim)
 
   let assemblyFormat = "`<` `{` struct(params) `}` `>`";
   let genVerifyDecl = 1;
-  let extraClassDeclaration = extraDistributedDeclaration # [{
-    SmallVector<unsigned> getContigPerThread() {
-      auto rank = getWarpsPerCTA().size();
-      assert(rank == 2 || rank == 3);
-      SmallVector<unsigned> contigPerThread(rank, 1);
-      auto kWidth = getKWidth();
-      assert(kWidth != 0 && "Do not support kWidth=0");
-      if (getOpIdx() == 0)
-        contigPerThread[rank - 1] = kWidth;
-      else
-        contigPerThread[rank - 2] = kWidth;
-      return contigPerThread;
-    };
-  }];
+  let extraClassDeclaration = extraDistributedDeclaration;
 }
 
 def TTG_SharedMemorySpace : AttrDef<TritonGPU_Dialect, "SharedMemorySpace"> {

lib/Analysis/Utility.cpp

Lines changed: 0 additions & 1 deletion
@@ -729,7 +729,6 @@ bool matchMFMAAndDotOperandShuffleCase(RankedTensorType srcTy,
   return dotOperandLayout.getParent() == mfmaLayout &&
          dotOperandLayout.getOpIdx() == 0 && mfmaLayout.getIsTransposed() &&
          dotOperandLayout.getKWidth() == 8 &&
-         getContigPerThread(mfmaLayout)[1] == 4 &&
          ((mfmaLayout.getMDim() == 16 && mfmaLayout.getNDim() == 16) ||
          (mfmaLayout.getMDim() == 32 && mfmaLayout.getNDim() == 32)) &&
          triton::type::isFloat8(srcTy.getElementType()) &&

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 4 additions & 12 deletions
@@ -113,19 +113,11 @@ SmallVector<unsigned> getSizePerThread(Attribute layout) {
   }
 }
 
-SmallVector<unsigned> getContigPerThread(Attribute layout) {
-  if (auto distributedLayout = dyn_cast<DistributedEncodingTrait>(layout)) {
-    return distributedLayout.getContigPerThread();
-  } else {
-    llvm::report_fatal_error("getContigPerThread not implemented");
-    return {};
-  }
-}
-
-SmallVector<unsigned> getUniqueContigPerThread(Attribute layout,
-                                               ArrayRef<int64_t> shape) {
+SmallVector<unsigned> getContigPerThread(RankedTensorType tensorType) {
+  auto layout = tensorType.getEncoding();
+  auto shape = tensorType.getShape();
   auto linearLayout = toLinearLayout(shape, layout);
-  auto llAttr = LinearEncodingAttr::get(layout.getContext(), linearLayout);
+  auto llAttr = LinearEncodingAttr::get(tensorType.getContext(), linearLayout);
   return llAttr.getContigPerThread();
 }

third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp

Lines changed: 9 additions & 6 deletions
@@ -5,6 +5,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "triton/Conversion/TritonGPUToLLVM/Utility.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
+#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
 
 using mlir::triton::ModuleAxisInfoAnalysis;
 using mlir::triton::AMD::DppCtrl;
@@ -536,12 +537,14 @@ unsigned getContiguity(Value ptr, Value offset,
   Type type = getPointerTypeWithShape(ptr, offset);
   RankedTensorType tensorTy = cast<RankedTensorType>(type);
   auto layout = tensorTy.getEncoding();
-  auto order = triton::gpu::getOrder(layout);
-  auto uniqueContigPerThread =
-      triton::gpu::getUniqueContigPerThread(layout, tensorTy.getShape());
-  assert(order[0] < uniqueContigPerThread.size() &&
-         "Unexpected uniqueContigPerThread size");
-  unsigned contiguity = uniqueContigPerThread[order[0]];
+  auto linearLayout = triton::gpu::toLinearLayout(tensorTy.getShape(), layout);
+  auto llAttr =
+      triton::gpu::LinearEncodingAttr::get(tensorTy.getContext(), linearLayout);
+  auto order = llAttr.getOrder();
+  auto contigPerThread = llAttr.getContigPerThread();
+  assert(order[0] < contigPerThread.size() &&
+         "Unexpected contigPerThread size");
+  unsigned contiguity = contigPerThread[order[0]];
 
   // Get alignment from the pointer. Since this is a scalar pointer
   // we should not take the pointer contiguity to consider alignment

unittest/Dialect/TritonGPU/DialectTest.cpp

Lines changed: 0 additions & 2 deletions
@@ -723,8 +723,6 @@ TEST_F(LinearEncodingTest, DistributedEncodingToLinearEncoding) {
     if (!is_dot_op_with_block_parent(distributedEncoding)) {
      ASSERT_EQ(distributedEncoding.getRepOrder(),
                linearEncoding.getRepOrder());
-      ASSERT_EQ(distributedEncoding.getContigPerThread(),
-                linearEncoding.getContigPerThread());
     }
     // DotOperandEncodingAttr::getWarpOrder() is not defined
     if (!isa<triton::gpu::DotOperandEncodingAttr>(distributedEncoding)) {
