Commit ae7a689

Merge commit '4d1ec3eef217053f13c490fe9187539289cd184a'
2 parents 08cad31 + 4d1ec3e commit ae7a689

26 files changed, +658 -212 lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
@@ -32,6 +32,9 @@ python/triton/backends/
 # Language extras
 python/triton/language/extra
 
+# Tools extras
+python/triton/tools/extra
+
 # Proton
 python/triton/profiler
 

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 22 additions & 7 deletions
@@ -892,9 +892,15 @@ inline void emitWmmaOffsetForCTA(const AMDWmmaEncodingAttr &wmmaLayout,
   if (rank == 3)
     elemOffset[0] = ctaBatchOffset;
   for (unsigned elem = 0; elem < elemsPerThreadPerGroup; elem++) {
-    elemOffset[rank - 2] =
-        ctaOffsetX * shapePerCta[rank - 2] + elemStride * elem;
-    elemOffset[rank - 1] = ctaOffsetY * shapePerCta[rank - 1];
+    if (wmmaLayout.getIsTransposed()) {
+      elemOffset[rank - 1] =
+          ctaOffsetX * shapePerCta[rank - 1] + elemStride * elem;
+      elemOffset[rank - 2] = ctaOffsetY * shapePerCta[rank - 2];
+    } else {
+      elemOffset[rank - 2] =
+          ctaOffsetX * shapePerCta[rank - 2] + elemStride * elem;
+      elemOffset[rank - 1] = ctaOffsetY * shapePerCta[rank - 1];
+    }
     offsets.push_back(elemOffset);
   }
 }
@@ -945,10 +951,19 @@ emitBaseIndexForWmmaLayout(Location loc, RewriterBase &rewriter,
         add(udiv(threadIdPerWarp, i32_val(mnkDim[2])), offWarp0);
   } else {
     assert(ver == 2);
-    multiDimBase[rank - 2] =
-        add(mul(udiv(threadIdPerWarp, i32_val(mnkDim[2])),
-                i32_val(wmmaLayout.getSizePerThread()[rank - 2])),
-            offWarp0);
+    if (wmmaLayout.getIsTransposed()) {
+      multiDimBase[rank - 1] =
+          add(mul(udiv(threadIdPerWarp, i32_val(16)),
+                  i32_val(wmmaLayout.getSizePerThread()[rank - 1])),
+              offWarp1);
+      multiDimBase[rank - 2] = add(laneId, offWarp0);
+    } else {
+      multiDimBase[rank - 2] =
+          add(mul(udiv(threadIdPerWarp, i32_val(16)),
+                  i32_val(wmmaLayout.getSizePerThread()[rank - 2])),
+              offWarp0);
+      multiDimBase[rank - 1] = add(laneId, offWarp1);
+    }
   }
   multiDimBase[rank - 1] = add(laneId, offWarp1);
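For reference, a minimal standalone C++ sketch (not the Triton helpers themselves) of what the new branch does: when the WMMA layout is transposed, the per-element stride moves to the last dimension and the two CTA offsets swap roles. The names shapePerCta, ctaOffsetX/Y, elemStride, and elem are local stand-ins for the values the real emitter computes.

#include <array>
#include <cstdio>

int main() {
  const unsigned rank = 2;
  std::array<unsigned, 2> shapePerCta{16, 16}; // per-CTA tile shape (stand-in values)
  unsigned ctaOffsetX = 1, ctaOffsetY = 0, elemStride = 1;
  unsigned elem = 3; // an arbitrary element index within the thread's group
  for (int isTransposed = 0; isTransposed < 2; ++isTransposed) {
    std::array<unsigned, 2> elemOffset{};
    if (isTransposed) {
      // Transposed path: the element stride advances along the last dimension.
      elemOffset[rank - 1] = ctaOffsetX * shapePerCta[rank - 1] + elemStride * elem;
      elemOffset[rank - 2] = ctaOffsetY * shapePerCta[rank - 2];
    } else {
      // Original path: the element stride advances along the second-to-last dimension.
      elemOffset[rank - 2] = ctaOffsetX * shapePerCta[rank - 2] + elemStride * elem;
      elemOffset[rank - 1] = ctaOffsetY * shapePerCta[rank - 1];
    }
    std::printf("isTransposed=%d -> elemOffset = (%u, %u)\n", isTransposed,
                elemOffset[0], elemOffset[1]);
  }
  return 0;
}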

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 27 additions & 5 deletions
@@ -961,7 +961,7 @@ is supported.
 
 // ----------------------------------- version = 1 ----------------------------------- //
 
-Row |                   warp 0                                   warp 2
+Row |                   warp 0                                   warp 1
     |/-------------------^-------------------\ /-------------------^-------------------\
 0   |[0 1 2 ... 14 15] [0 1 2 ... 14 15]   [0 1 2 ... 14 15] [0 1 2 ... 14 15]
 1   |[16 17 18 ... 30 31] [16 17 18 ... 30 31]   [16 17 18 ... 30 31] [16 17 18 ... 30 31]
@@ -971,7 +971,7 @@ Row |                   warp 0                                   warp 2
 14  |[0 1 2 ... 14 15] [0 1 2 ... 14 15]   [0 1 2 ... 14 15] [0 1 2 ... 14 15]
 15  |[16 17 18 ... 30 31] [16 17 18 ... 30 31]   [16 17 18 ... 30 31] [16 17 18 ... 30 31]
 
-    |                   warp 1                                   warp 3
+    |                   warp 2                                   warp 3
 16  |/-------------------^-------------------\ /-------------------^-------------------\
 17  |[0 1 2 ... 14 15] [0 1 2 ... 14 15]   [0 1 2 ... 14 15] [0 1 2 ... 14 15]
 18  |[16 17 18 ... 30 31] [16 17 18 ... 30 31]   [16 17 18 ... 30 31] [16 17 18 ... 30 31]
@@ -981,9 +981,9 @@ Row |                   warp 0                                   warp 2
 30  |[0 1 2 ... 14 15] [0 1 2 ... 14 15]   [0 1 2 ... 14 15] [0 1 2 ... 14 15]
 31  |[16 17 18 ... 30 31] [16 17 18 ... 30 31]   [16 17 18 ... 30 31] [16 17 18 ... 30 31]
 
-// ----------------------------------- version = 2 ----------------------------------- //
+// ------------------------ version = 2, isTransposed = false ------------------------ //
 
-Row |        warp 0             warp 2
+Row |        warp 0             warp 1
     |/--------^---------\ /---------^--------\
 0   |[0 1 2 ... 14 15]   [0 1 2 ... 14 15]
 1   |[0 1 2 ... 14 15]   [0 1 2 ... 14 15]
@@ -996,7 +996,7 @@ Row |        warp 0             warp 2
 14  |[16 17 18 ... 30 31] [16 17 18 ... 30 31]
 15  |[16 17 18 ... 30 31] [16 17 18 ... 30 31]
     |
-    |        warp 1             warp 3
+    |        warp 2             warp 3
     |/--------^---------\ /---------^--------\
 16  |[0 1 2 ... 14 15]   [0 1 2 ... 14 15]
 17  |[0 1 2 ... 14 15]   [0 1 2 ... 14 15]
@@ -1008,15 +1008,37 @@ Row |        warp 0             warp 2
 ..  | ...                  ...
 30  |[16 17 18 ... 30 31] [16 17 18 ... 30 31]
 31  |[16 17 18 ... 30 31] [16 17 18 ... 30 31]
+
+// ------------------------ version = 2, isTransposed = true ------------------------ //
+
+    |              warp 0                       warp 1
+    |/----------------^----------------\ /-------^-------\
+Col>| 0  1  2  3  4  5  6  7  8 ... 15    16 17 18 ... 32
+Row |
+0   |[0  0  0  0  0  0  0  0  16 ... 16]  [0  0  0 ... 16]
+1   |[1  1  1  1  1  1  1  1  17 ... 17]  [1  1  1 ... 17]
+..  | ...                                  ...
+14  |[14 14 14 14 14 14 14 14 30 ... 30]  [14 14 14 ... 30]
+15  |[15 15 15 15 15 15 15 15 31 ... 31]  [15 15 15 ... 31]
+    |
+    |              warp 2                       warp 3
+    |/----------------^----------------\ /-------^-------\
+16  |[0  0  0  0  0  0  0  0  16 ... 16]  [0  0  0 ... 16]
+17  |[1  1  1  1  1  1  1  1  17 ... 17]  [1  1  1 ... 17]
+..  | ...                                  ...
+30  |[14 14 14 14 14 14 14 14 30 ... 30]  [14 14 14 ... 30]
+31  |[15 15 15 15 15 15 15 15 31 ... 31]  [15 15 15 ... 31]
 }];
 
   let parameters = (
     ins
     "unsigned": $version,
+    "bool":$isTransposed,
     ArrayRefParameter<"unsigned">:$warpsPerCTA__,
     "CTALayoutAttr":$CTALayout
   );
 
+  let genVerifyDecl = 1;
   let hasCustomAssemblyFormat = 1;
 
   let extraClassDeclaration = extraDistributedDeclaration # [{
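As a sanity check on the new diagram, a small sketch of the lane-to-element mapping for the warp-0 tile (rows and columns 0-15) of the version-2 layouts. The formula is read off the pictures above, not taken from the Triton sources, so treat it as an illustration only.

#include <cstdio>

// Lane that owns element (row, col) of the warp-0 tile, as pictured above.
unsigned laneOfElement(unsigned row, unsigned col, bool isTransposed) {
  if (isTransposed)
    return row + 16 * (col / 8); // transposed: lane IDs run down the rows
  return col + 16 * (row / 8);   // non-transposed: lane IDs run across the columns
}

int main() {
  // Element (row 1, col 9) of the warp-0 tile:
  std::printf("non-transposed: lane %u\n", laneOfElement(1, 9, false)); // lane 9
  std::printf("transposed:     lane %u\n", laneOfElement(1, 9, true));  // lane 17
  return 0;
}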

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp

Lines changed: 14 additions & 4 deletions
@@ -98,13 +98,13 @@ void storeValuesInLinearVector(PatternRewriter &rewriter, Location loc,
   }
 }
 
-void verifyCTALayout(CTALayoutAttr ctaLayout) {
+bool verifyCTALayout(CTALayoutAttr ctaLayout) {
   auto ctaSplit = ctaLayout.getCTASplitNum();
   for (auto split : ctaSplit) {
     if (split != 1)
-      llvm::report_fatal_error("tensors splited in CGA(thread group clusters) "
-                               "are not supported in FMA dot yet.");
+      return false;
   }
+  return true;
 }
 
 /// Get a linear offset of first element loaded by thread.
@@ -216,7 +216,8 @@ Value loadFMAOp(Value srcVal, Value llVal, BlockedEncodingAttr dLayout,
                 Value thread, Location loc,
                 const LLVMTypeConverter *typeConverter,
                 ConversionPatternRewriter &rewriter, const int dotOpNo) {
-  verifyCTALayout(dLayout.getCTALayout());
+  if (!verifyCTALayout(dLayout.getCTALayout()))
+    return Value();
 
   DimIdx dim;
   dim.batch = 0;
@@ -292,6 +293,15 @@ Value loadFMAOp(Value srcVal, Value llVal, BlockedEncodingAttr dLayout,
   auto numBTiles = std::max(1u, B / shapePerCTABTile);
   auto numNonKTiles = std::max(1u, NonK / shapePerCTANonKTile);
 
+  // Found a discrepancy in this case;
+  // use the linear layout based converter for this case.
+  // TODO: break batch and non-K dimension iterations into
+  // "repeat" and "inside-repeat" parts, pack them into an LLVM structure
+  // according to repeat and register order.
+  // See FMA.cpp:getValueTableFromStructFMA for reference.
+  if (numBTiles != 1 || numNonKTiles != 1)
+    return Value();
+
   auto perThreadShape =
       getElemsPerThreadInOp(opTensorShape, shapePerCTATile, sizePerThread);
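The net effect of these hunks is that loadFMAOp now signals unsupported cases instead of aborting: an empty Value tells the caller to fall back to the linear-layout based converter. Below is a hypothetical standalone sketch of that contract; DummyValue stands in for mlir::Value, and the function and parameter names are illustrative, not from the sources.

#include <cstdio>

struct DummyValue {
  bool hasResult;
  explicit operator bool() const { return hasResult; }
};

DummyValue loadFMAOpSketch(bool ctaSplitIsTrivial, unsigned numBTiles,
                           unsigned numNonKTiles) {
  if (!ctaSplitIsTrivial)
    return DummyValue{false}; // unsupported CGA split: let the caller fall back
  if (numBTiles != 1 || numNonKTiles != 1)
    return DummyValue{false}; // multi-tile case: use the linear-layout converter
  return DummyValue{true};    // legacy FMA path handled the load
}

int main() {
  DummyValue res = loadFMAOpSketch(/*ctaSplitIsTrivial=*/true, /*numBTiles=*/2,
                                   /*numNonKTiles=*/1);
  std::printf(res ? "legacy FMA path\n" : "fall back to linear-layout converter\n");
  return 0;
}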

lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp

Lines changed: 72 additions & 30 deletions
@@ -13,24 +13,51 @@ using ::mlir::triton::gpu::expandMatrixShapeWithBatch;
 using ::mlir::triton::gpu::getShapePerCTA;
 using ::mlir::triton::gpu::getSizePerThread;
 
-using ValueTableFMA = std::map<std::tuple<int, int, int>, Value>;
+/// \brief spatial position of repetition and register of a given value
+struct OperandValueKey {
+  unsigned bRepIdx, nonKRepIdx;
+  unsigned bIdx, nonKIdx, kIdx;
+
+  bool operator==(const OperandValueKey &other) const {
+    return (bRepIdx == other.bRepIdx && nonKRepIdx == other.nonKRepIdx &&
+            bIdx == other.bIdx && nonKIdx == other.nonKIdx &&
+            kIdx == other.kIdx);
+  }
+};
+
+template <> struct std::hash<OperandValueKey> {
+  std::size_t operator()(const OperandValueKey &k) const {
+    return llvm::hash_combine(k.bRepIdx, k.nonKRepIdx, k.bIdx, k.nonKIdx,
+                              k.kIdx);
+  }
+};
+
+using ValueTableFMA = std::unordered_map<OperandValueKey, Value>;
 
-static ValueTableFMA
-getValueTableFromStructFMA(Value val, ArrayRef<unsigned> perTileShape,
-                           unsigned kDim, unsigned nonKDim,
-                           ConversionPatternRewriter &rewriter, Location loc,
-                           ArrayRef<unsigned> order) {
+static ValueTableFMA getValueTableFromStructFMA(
+    Value val, ArrayRef<unsigned> perRepShape, ArrayRef<unsigned> repetitions,
+    unsigned kDim, unsigned nonKDim, ConversionPatternRewriter &rewriter,
+    Location loc, ArrayRef<unsigned> inRepOrder, ArrayRef<unsigned> repOrder) {
   ValueTableFMA res;
   auto elems = unpackLLElements(loc, val, rewriter);
-  assert(perTileShape.size() == 3);
-  assert(elems.size() == product(perTileShape));
+  assert(perRepShape.size() == 3);
+  auto numElemsRep = product(perRepShape);
+  assert(elems.size() == numElemsRep * product(repetitions));
   assert(kDim == 1 || kDim == 2);
   assert(nonKDim == 1 || nonKDim == 2);
   const unsigned bDim = 0;
 
   for (unsigned idx = 0; idx < elems.size(); ++idx) {
-    auto spatialIdx = mlir::LLVM::delinearize(idx, perTileShape, order);
-    res[{spatialIdx[bDim], spatialIdx[nonKDim], spatialIdx[kDim]}] = elems[idx];
+    auto inRepLinearIdx = idx % numElemsRep;
+    auto repLinearIdx = idx / numElemsRep;
+    auto inRepSpatialIdx =
+        mlir::LLVM::delinearize(inRepLinearIdx, perRepShape, inRepOrder);
+    auto repSpatialIdx =
+        mlir::LLVM::delinearize(repLinearIdx, repetitions, repOrder);
+    OperandValueKey key{repSpatialIdx[0], repSpatialIdx[nonKDim],
+                        inRepSpatialIdx[0], inRepSpatialIdx[nonKDim],
+                        inRepSpatialIdx[kDim]};
+    res[key] = elems[idx];
  }
   return res;
 }
@@ -54,46 +81,61 @@ LogicalResult convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor,
 
   BlockedEncodingAttr dLayout =
       cast<BlockedEncodingAttr>(dTensorTy.getEncoding());
-  auto order = expandMatrixOrderWithBatch(dLayout.getOrder());
+  // TODO process A and B operand separately
+  auto inRepOrder = expandMatrixOrderWithBatch(dLayout.getOrder());
+  auto repOrder = expandMatrixOrderWithBatch(dLayout.getRepOrder());
   auto cc = unpackLLElements(loc, adaptor.getC(), rewriter);
 
   Value llA = adaptor.getA();
   Value llB = adaptor.getB();
 
   auto sizePerThread =
       expandMatrixShapeWithBatch(ArrayRef(getSizePerThread(dLayout)));
+  auto numElemsPerThread = product(sizePerThread);
   auto shapePerCTATile =
       expandMatrixShapeWithBatch(ArrayRef(getShapePerCTATile(dLayout)));
 
   unsigned K = aShapePerCTA[2];
 
-  unsigned perThreadShape[3];
+  unsigned threadTileShape[3];
+  unsigned repetitions[3];
   for (int i = 0; i < 3; ++i) {
-    unsigned numRep = dShapePerCTA[i] / shapePerCTATile[i];
-    numRep = std::max(static_cast<unsigned>(1), numRep);
-    perThreadShape[i] = numRep * sizePerThread[i];
+    repetitions[i] =
+        ceil(dShapePerCTA[i], static_cast<int64_t>(shapePerCTATile[i]));
   }
 
   auto has = getValueTableFromStructFMA(
-      llA, {perThreadShape[0], perThreadShape[1], K},
-      /*kDim*/ 2, /*nonKDim*/ 1, rewriter, loc, order);
+      llA, {sizePerThread[0], sizePerThread[1], K},
+      {repetitions[0], repetitions[1], 1},
+      /*kDim*/ 2, /*nonKDim*/ 1, rewriter, loc, inRepOrder, repOrder);
   auto hbs = getValueTableFromStructFMA(
-      llB, {perThreadShape[0], K, perThreadShape[2]},
-      /*kDim*/ 1, /*nonKDim*/ 2, rewriter, loc, order);
+      llB, {sizePerThread[0], K, sizePerThread[2]},
+      {repetitions[0], 1, repetitions[2]},
+      /*kDim*/ 1, /*nonKDim*/ 2, rewriter, loc, inRepOrder, repOrder);
 
   SmallVector<Value> acc = cc;
 
-  for (unsigned b = 0; b < perThreadShape[0]; ++b)
-    for (unsigned m = 0; m < perThreadShape[1]; ++m)
-      for (unsigned n = 0; n < perThreadShape[2]; ++n) {
-        SmallVector<unsigned> multiDimAccumIdx = {b, m, n};
-        unsigned linearAccumIdx =
-            linearize(multiDimAccumIdx, perThreadShape, order);
-        for (unsigned k = 0; k < K; ++k) {
-          acc[linearAccumIdx] = rewriter.create<LLVM::FMulAddOp>(
-              loc, has[{b, m, k}], hbs[{b, n, k}], acc[linearAccumIdx]);
-        }
-      }
+  for (unsigned bRep = 0; bRep < repetitions[0]; ++bRep)
+    for (unsigned mRep = 0; mRep < repetitions[1]; ++mRep)
+      for (unsigned nRep = 0; nRep < repetitions[2]; ++nRep)
+        for (unsigned b = 0; b < sizePerThread[0]; ++b)
+          for (unsigned m = 0; m < sizePerThread[1]; ++m)
+            for (unsigned n = 0; n < sizePerThread[2]; ++n) {
+              SmallVector<unsigned> multiDimAccumIdx = {b, m, n};
+              unsigned linearInRepIdx =
+                  linearize(multiDimAccumIdx, sizePerThread, inRepOrder);
+              SmallVector<unsigned> multiDimRepIdx = {bRep, mRep, nRep};
+              unsigned linearRepIdx =
+                  linearize(multiDimRepIdx, repetitions, repOrder);
+              unsigned linearAccumIdx =
                  linearInRepIdx + linearRepIdx * numElemsPerThread;
+              for (unsigned k = 0; k < K; ++k) {
+                auto aOp = has[{bRep, mRep, b, m, k}];
+                auto bOp = hbs[{bRep, nRep, b, n, k}];
+                acc[linearAccumIdx] = rewriter.create<LLVM::FMulAddOp>(
+                    loc, aOp, bOp, acc[linearAccumIdx]);
+              }
+            }
 
   auto res = packLLElements(loc, typeConverter, acc, rewriter, dTensorTy);
   rewriter.replaceOp(op, res);
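A self-contained sketch of the accumulator indexing introduced above: values are grouped per repetition, so a value's flat accumulator slot is its in-repetition index plus the repetition index scaled by the per-repetition element count. linearize3 is a local helper written for this illustration (order lists the fastest-varying dimension first), not the MLIR utility.

#include <array>
#include <cstdio>

// Linearize a 3D index with the given dimension order (fastest-varying first).
unsigned linearize3(std::array<unsigned, 3> idx, std::array<unsigned, 3> shape,
                    std::array<unsigned, 3> order) {
  unsigned linear = 0;
  for (int i = 2; i >= 0; --i) {
    unsigned d = order[i];
    linear = linear * shape[d] + idx[d];
  }
  return linear;
}

int main() {
  std::array<unsigned, 3> sizePerThread{1, 2, 2}; // {batch, m, n} per repetition
  std::array<unsigned, 3> repetitions{1, 2, 2};   // repetitions per thread
  std::array<unsigned, 3> order{2, 1, 0};         // n fastest, then m, then batch
  unsigned numElemsPerThread =
      sizePerThread[0] * sizePerThread[1] * sizePerThread[2];

  // Accumulator slot for in-rep element (b=0, m=1, n=0) of repetition (0, 1, 0):
  unsigned linearInRepIdx = linearize3({0, 1, 0}, sizePerThread, order);
  unsigned linearRepIdx = linearize3({0, 1, 0}, repetitions, order);
  unsigned linearAccumIdx = linearInRepIdx + linearRepIdx * numElemsPerThread;
  std::printf("in-rep %u, rep %u -> accumulator %u\n", linearInRepIdx,
              linearRepIdx, linearAccumIdx); // prints: in-rep 2, rep 2 -> accumulator 10
  return 0;
}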

lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 1 addition & 42 deletions
@@ -119,54 +119,13 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
       : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) {
   }
 
-  // FIXME [Dot LL]
-  // Do for all DotOperandEncodingAttr once we have LLs for all of them
-  static bool isSupportedLayout(Attribute dstLayout) {
-    if (isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr,
-            LinearEncodingAttr>(dstLayout))
-      return true;
-    if (auto dot = dyn_cast<DotOperandEncodingAttr>(dstLayout)) {
-      if (isa<MmaEncodingTrait>(dot.getParent()))
-        return true;
-    }
-    return false;
-  };
-
   LogicalResult
   matchAndRewrite(LocalLoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    RankedTensorType dstTy = op.getType();
-    Attribute dstLayout = dstTy.getEncoding();
-    if (isSupportedLayout(dstLayout)) {
-      return lowerSharedToDistributed(op, adaptor, getTypeConverter(),
-                                      rewriter);
-    }
-    if (isa<DotOperandEncodingAttr>(dstLayout) &&
-        isa<BlockedEncodingAttr>(
-            cast<DotOperandEncodingAttr>(dstLayout).getParent())) {
-      return lowerSharedToDotOpFMA(op, adaptor, getTypeConverter(), rewriter);
-    }
-    return failure();
+    return lowerSharedToDistributed(op, adaptor, getTypeConverter(), rewriter);
   }
 
 private:
-  LogicalResult
-  lowerSharedToDotOpFMA(LocalLoadOp op, LocalLoadOpAdaptor adaptor,
-                        const LLVMTypeConverter *typeConverter,
-                        ConversionPatternRewriter &rewriter) const {
-    auto loc = op.getLoc();
-    RankedTensorType dstTy = op.getType();
-    Attribute dstLayout = dstTy.getEncoding();
-    auto dotLayout = cast<DotOperandEncodingAttr>(dstLayout);
-    auto blockedLayout = cast<BlockedEncodingAttr>(
-        cast<DotOperandEncodingAttr>(dstLayout).getParent());
-    auto thread = getThreadId(rewriter, loc);
-    Value res = SharedToDotOperandFMA::convertLayout(
-        dotLayout.getOpIdx(), op.getSrc(), adaptor.getSrc(), blockedLayout,
-        thread, loc, getTypeConverter(), rewriter);
-    rewriter.replaceOp(op, res);
-    return success();
-  }
   LogicalResult
   lowerSharedToDistributed(LocalLoadOp op, LocalLoadOpAdaptor adaptor,
                            const LLVMTypeConverter *typeConverter,
