
Commit 0d118de

Revert "Revert "[Backend] Convert FMA dot operand to linear layout (#5469)""
This reverts commit b6f6b41.
1 parent db993b6 · commit 0d118de

6 files changed: +225 −94 lines changed

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp

Lines changed: 14 additions & 4 deletions
```diff
@@ -98,13 +98,13 @@ void storeValuesInLinearVector(PatternRewriter &rewriter, Location loc,
   }
 }

-void verifyCTALayout(CTALayoutAttr ctaLayout) {
+bool verifyCTALayout(CTALayoutAttr ctaLayout) {
   auto ctaSplit = ctaLayout.getCTASplitNum();
   for (auto split : ctaSplit) {
     if (split != 1)
-      llvm::report_fatal_error("tensors splited in CGA(thread group clusters) "
-                               "are not supported in FMA dot yet.");
+      return false;
   }
+  return true;
 }

 /// Get a linear offset of first element loaded by thread.
@@ -216,7 +216,8 @@ Value loadFMAOp(Value srcVal, Value llVal, BlockedEncodingAttr dLayout,
                 Value thread, Location loc,
                 const LLVMTypeConverter *typeConverter,
                 ConversionPatternRewriter &rewriter, const int dotOpNo) {
-  verifyCTALayout(dLayout.getCTALayout());
+  if (!verifyCTALayout(dLayout.getCTALayout()))
+    return Value();

   DimIdx dim;
   dim.batch = 0;
@@ -292,6 +293,15 @@ Value loadFMAOp(Value srcVal, Value llVal, BlockedEncodingAttr dLayout,
   auto numBTiles = std::max(1u, B / shapePerCTABTile);
   auto numNonKTiles = std::max(1u, NonK / shapePerCTANonKTile);

+  // A discrepancy was found in this case; use the linear-layout-based
+  // converter instead.
+  // TODO: break batch and non-K dimension iterations into "repeat" and
+  // "inside-repeat" parts, and pack them into the LLVM structure according
+  // to repeat and register order.
+  // See FMA.cpp:getValueTableFromStructFMA for reference.
+  if (numBTiles != 1 || numNonKTiles != 1)
+    return Value();
+
   auto perThreadShape =
       getElemsPerThreadInOp(opTensorShape, shapePerCTATile, sizePerThread);
```
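The two hunks above change the failure mode: `verifyCTALayout` reports instead of aborting, and `loadFMAOp` signals "unsupported" by returning an empty `Value`, leaving the caller to fall back to the linear-layout converter. A minimal standalone sketch of that null-result contract, with `Value` and `legacyConvert` as stand-ins (an empty `mlir::Value` is likewise contextually false):

```cpp
#include <cstdio>

// Stand-in for mlir::Value: an empty result is contextually false.
struct Value {
  void *impl = nullptr;
  explicit operator bool() const { return impl != nullptr; }
};

// Hypothetical converter mirroring the loadFMAOp contract above: return an
// empty Value for unsupported cases instead of calling report_fatal_error.
Value legacyConvert(bool ctaSplitIsTrivial, bool singleTilePerCTA) {
  if (!ctaSplitIsTrivial || !singleTilePerCTA)
    return Value(); // unsupported: caller should use the linear-layout path
  static int token;
  return Value{&token}; // success placeholder
}

int main() {
  if (!legacyConvert(/*ctaSplitIsTrivial=*/true, /*singleTilePerCTA=*/false))
    std::puts("falling back to the linear-layout converter");
}
```
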
lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp

Lines changed: 72 additions & 30 deletions
```diff
@@ -13,24 +13,51 @@ using ::mlir::triton::gpu::expandMatrixShapeWithBatch;
 using ::mlir::triton::gpu::getShapePerCTA;
 using ::mlir::triton::gpu::getSizePerThread;

-using ValueTableFMA = std::map<std::tuple<int, int, int>, Value>;
+/// \brief spatial position of repetition and register of a given value
+struct OperandValueKey {
+  unsigned bRepIdx, nonKRepIdx;
+  unsigned bIdx, nonKIdx, kIdx;
+
+  bool operator==(const OperandValueKey &other) const {
+    return (bRepIdx == other.bRepIdx && nonKRepIdx == other.nonKRepIdx &&
+            bIdx == other.bIdx && nonKIdx == other.nonKIdx &&
+            kIdx == other.kIdx);
+  }
+};
+
+template <> struct std::hash<OperandValueKey> {
+  std::size_t operator()(const OperandValueKey &k) const {
+    return llvm::hash_combine(k.bRepIdx, k.nonKRepIdx, k.bIdx, k.nonKIdx,
+                              k.kIdx);
+  }
+};
+
+using ValueTableFMA = std::unordered_map<OperandValueKey, Value>;

-static ValueTableFMA
-getValueTableFromStructFMA(Value val, ArrayRef<unsigned> perTileShape,
-                           unsigned kDim, unsigned nonKDim,
-                           ConversionPatternRewriter &rewriter, Location loc,
-                           ArrayRef<unsigned> order) {
+static ValueTableFMA getValueTableFromStructFMA(
+    Value val, ArrayRef<unsigned> perRepShape, ArrayRef<unsigned> repetitions,
+    unsigned kDim, unsigned nonKDim, ConversionPatternRewriter &rewriter,
+    Location loc, ArrayRef<unsigned> inRepOrder, ArrayRef<unsigned> repOrder) {
   ValueTableFMA res;
   auto elems = unpackLLElements(loc, val, rewriter);
-  assert(perTileShape.size() == 3);
-  assert(elems.size() == product(perTileShape));
+  assert(perRepShape.size() == 3);
+  auto numElemsRep = product(perRepShape);
+  assert(elems.size() == numElemsRep * product(repetitions));
   assert(kDim == 1 || kDim == 2);
   assert(nonKDim == 1 || nonKDim == 2);
   const unsigned bDim = 0;

   for (unsigned idx = 0; idx < elems.size(); ++idx) {
-    auto spatialIdx = mlir::LLVM::delinearize(idx, perTileShape, order);
-    res[{spatialIdx[bDim], spatialIdx[nonKDim], spatialIdx[kDim]}] = elems[idx];
+    auto inRepLinearIdx = idx % numElemsRep;
+    auto repLinearIdx = idx / numElemsRep;
+    auto inRepSpatialIdx =
+        mlir::LLVM::delinearize(inRepLinearIdx, perRepShape, inRepOrder);
+    auto repSpatialIdx =
+        mlir::LLVM::delinearize(repLinearIdx, repetitions, repOrder);
+    OperandValueKey key{repSpatialIdx[0], repSpatialIdx[nonKDim],
+                        inRepSpatialIdx[0], inRepSpatialIdx[nonKDim],
+                        inRepSpatialIdx[kDim]};
+    res[key] = elems[idx];
   }
   return res;
 }
```
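Switching `ValueTableFMA` from a `std::map` keyed by `std::tuple` to a `std::unordered_map` keyed by a named struct requires exactly the two pieces the hunk adds: `operator==` and a `std::hash` specialization. A self-contained sketch of the same pattern; `hashCombine` is a Boost-style stand-in for `llvm::hash_combine`:

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <initializer_list>
#include <unordered_map>

// Key identifying a value by repetition indices plus in-repetition indices.
struct OperandValueKey {
  unsigned bRepIdx, nonKRepIdx;
  unsigned bIdx, nonKIdx, kIdx;
  bool operator==(const OperandValueKey &o) const {
    return bRepIdx == o.bRepIdx && nonKRepIdx == o.nonKRepIdx &&
           bIdx == o.bIdx && nonKIdx == o.nonKIdx && kIdx == o.kIdx;
  }
};

// Boost-style mixer; llvm::hash_combine plays this role in the real code.
static std::size_t hashCombine(std::size_t seed, unsigned v) {
  return seed ^ (std::hash<unsigned>{}(v) + 0x9e3779b9 + (seed << 6) +
                 (seed >> 2));
}

template <> struct std::hash<OperandValueKey> {
  std::size_t operator()(const OperandValueKey &k) const {
    std::size_t h = 0;
    for (unsigned v : {k.bRepIdx, k.nonKRepIdx, k.bIdx, k.nonKIdx, k.kIdx})
      h = hashCombine(h, v);
    return h;
  }
};

int main() {
  std::unordered_map<OperandValueKey, int> table;
  OperandValueKey key{0, 1, 0, 2, 3};
  table[key] = 42;
  assert(table.at(key) == 42);
}
```
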
```diff
@@ -54,46 +81,61 @@ LogicalResult convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor,

   BlockedEncodingAttr dLayout =
       cast<BlockedEncodingAttr>(dTensorTy.getEncoding());
-  auto order = expandMatrixOrderWithBatch(dLayout.getOrder());
+  // TODO: process A and B operands separately
+  auto inRepOrder = expandMatrixOrderWithBatch(dLayout.getOrder());
+  auto repOrder = expandMatrixOrderWithBatch(dLayout.getRepOrder());
   auto cc = unpackLLElements(loc, adaptor.getC(), rewriter);

   Value llA = adaptor.getA();
   Value llB = adaptor.getB();

   auto sizePerThread =
       expandMatrixShapeWithBatch(ArrayRef(getSizePerThread(dLayout)));
+  auto numElemsPerThread = product(sizePerThread);
   auto shapePerCTATile =
       expandMatrixShapeWithBatch(ArrayRef(getShapePerCTATile(dLayout)));

   unsigned K = aShapePerCTA[2];

-  unsigned perThreadShape[3];
+  unsigned threadTileShape[3];
+  unsigned repetitions[3];
   for (int i = 0; i < 3; ++i) {
-    unsigned numRep = dShapePerCTA[i] / shapePerCTATile[i];
-    numRep = std::max(static_cast<unsigned>(1), numRep);
-    perThreadShape[i] = numRep * sizePerThread[i];
+    repetitions[i] =
+        ceil(dShapePerCTA[i], static_cast<int64_t>(shapePerCTATile[i]));
   }

   auto has = getValueTableFromStructFMA(
-      llA, {perThreadShape[0], perThreadShape[1], K},
-      /*kDim*/ 2, /*nonKDim*/ 1, rewriter, loc, order);
+      llA, {sizePerThread[0], sizePerThread[1], K},
+      {repetitions[0], repetitions[1], 1},
+      /*kDim*/ 2, /*nonKDim*/ 1, rewriter, loc, inRepOrder, repOrder);
   auto hbs = getValueTableFromStructFMA(
-      llB, {perThreadShape[0], K, perThreadShape[2]},
-      /*kDim*/ 1, /*nonKDim*/ 2, rewriter, loc, order);
+      llB, {sizePerThread[0], K, sizePerThread[2]},
+      {repetitions[0], 1, repetitions[2]},
+      /*kDim*/ 1, /*nonKDim*/ 2, rewriter, loc, inRepOrder, repOrder);

   SmallVector<Value> acc = cc;

-  for (unsigned b = 0; b < perThreadShape[0]; ++b)
-    for (unsigned m = 0; m < perThreadShape[1]; ++m)
-      for (unsigned n = 0; n < perThreadShape[2]; ++n) {
-        SmallVector<unsigned> multiDimAccumIdx = {b, m, n};
-        unsigned linearAccumIdx =
-            linearize(multiDimAccumIdx, perThreadShape, order);
-        for (unsigned k = 0; k < K; ++k) {
-          acc[linearAccumIdx] = rewriter.create<LLVM::FMulAddOp>(
-              loc, has[{b, m, k}], hbs[{b, n, k}], acc[linearAccumIdx]);
-        }
-      }
+  for (unsigned bRep = 0; bRep < repetitions[0]; ++bRep)
+    for (unsigned mRep = 0; mRep < repetitions[1]; ++mRep)
+      for (unsigned nRep = 0; nRep < repetitions[2]; ++nRep)
+        for (unsigned b = 0; b < sizePerThread[0]; ++b)
+          for (unsigned m = 0; m < sizePerThread[1]; ++m)
+            for (unsigned n = 0; n < sizePerThread[2]; ++n) {
+              SmallVector<unsigned> multiDimAccumIdx = {b, m, n};
+              unsigned linearInRepIdx =
+                  linearize(multiDimAccumIdx, sizePerThread, inRepOrder);
+              SmallVector<unsigned> multiDimRepIdx = {bRep, mRep, nRep};
+              unsigned linearRepIdx =
+                  linearize(multiDimRepIdx, repetitions, repOrder);
+              unsigned linearAccumIdx =
+                  linearInRepIdx + linearRepIdx * numElemsPerThread;
+              for (unsigned k = 0; k < K; ++k) {
+                auto aOp = has[{bRep, mRep, b, m, k}];
+                auto bOp = hbs[{bRep, nRep, b, n, k}];
+                acc[linearAccumIdx] = rewriter.create<LLVM::FMulAddOp>(
+                    loc, aOp, bOp, acc[linearAccumIdx]);
+              }
+            }

   auto res = packLLElements(loc, typeConverter, acc, rewriter, dTensorTy);
   rewriter.replaceOp(op, res);
```
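The rewritten loop nest addresses each accumulator register by splitting the flat element index into an inside-repetition part and a repetition part, each (de)linearized in its own order. A standalone round-trip sketch of that mixed-radix split, assuming `order[0]` is the fastest-varying dimension as in `mlir::LLVM::delinearize`; the shapes and orders are illustrative:

```cpp
#include <cassert>
#include <vector>

// Mimics mlir::LLVM::delinearize: order[0] is the fastest-varying dimension.
std::vector<unsigned> delinearize(unsigned linear,
                                  const std::vector<unsigned> &shape,
                                  const std::vector<unsigned> &order) {
  std::vector<unsigned> multi(shape.size());
  for (unsigned dim : order) {
    multi[dim] = linear % shape[dim];
    linear /= shape[dim];
  }
  return multi;
}

// Inverse of delinearize: fold dimensions back up, slowest-varying first.
unsigned linearize(const std::vector<unsigned> &multi,
                   const std::vector<unsigned> &shape,
                   const std::vector<unsigned> &order) {
  unsigned linear = 0;
  for (auto it = order.rbegin(); it != order.rend(); ++it)
    linear = linear * shape[*it] + multi[*it];
  return linear;
}

int main() {
  std::vector<unsigned> sizePerThread = {1, 2, 2}; // b, m, n inside one rep
  std::vector<unsigned> repetitions = {1, 2, 3};
  std::vector<unsigned> inRepOrder = {2, 1, 0}, repOrder = {2, 1, 0};
  unsigned numElemsPerThread = 1 * 2 * 2;

  for (unsigned idx = 0; idx < numElemsPerThread * 1 * 2 * 3; ++idx) {
    unsigned inRep = idx % numElemsPerThread; // inside-repetition part
    unsigned rep = idx / numElemsPerThread;   // repetition part
    auto inRepIdx = delinearize(inRep, sizePerThread, inRepOrder);
    auto repIdx = delinearize(rep, repetitions, repOrder);
    // Round trip reproduces the flat accumulator index used above.
    assert(linearize(inRepIdx, sizePerThread, inRepOrder) +
               linearize(repIdx, repetitions, repOrder) * numElemsPerThread ==
           idx);
  }
}
```
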

lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 1 addition & 42 deletions
```diff
@@ -119,54 +119,13 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
       : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) {
   }

-  // FIXME [Dot LL]
-  // Do for all DotOperandEncodingAttr once we have LLs for all of them
-  static bool isSupportedLayout(Attribute dstLayout) {
-    if (isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr,
-            LinearEncodingAttr>(dstLayout))
-      return true;
-    if (auto dot = dyn_cast<DotOperandEncodingAttr>(dstLayout)) {
-      if (isa<MmaEncodingTrait>(dot.getParent()))
-        return true;
-    }
-    return false;
-  };
-
   LogicalResult
   matchAndRewrite(LocalLoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    RankedTensorType dstTy = op.getType();
-    Attribute dstLayout = dstTy.getEncoding();
-    if (isSupportedLayout(dstLayout)) {
-      return lowerSharedToDistributed(op, adaptor, getTypeConverter(),
-                                      rewriter);
-    }
-    if (isa<DotOperandEncodingAttr>(dstLayout) &&
-        isa<BlockedEncodingAttr>(
-            cast<DotOperandEncodingAttr>(dstLayout).getParent())) {
-      return lowerSharedToDotOpFMA(op, adaptor, getTypeConverter(), rewriter);
-    }
-    return failure();
+    return lowerSharedToDistributed(op, adaptor, getTypeConverter(), rewriter);
   }

 private:
-  LogicalResult
-  lowerSharedToDotOpFMA(LocalLoadOp op, LocalLoadOpAdaptor adaptor,
-                        const LLVMTypeConverter *typeConverter,
-                        ConversionPatternRewriter &rewriter) const {
-    auto loc = op.getLoc();
-    RankedTensorType dstTy = op.getType();
-    Attribute dstLayout = dstTy.getEncoding();
-    auto dotLayout = cast<DotOperandEncodingAttr>(dstLayout);
-    auto blockedLayout = cast<BlockedEncodingAttr>(
-        cast<DotOperandEncodingAttr>(dstLayout).getParent());
-    auto thread = getThreadId(rewriter, loc);
-    Value res = SharedToDotOperandFMA::convertLayout(
-        dotLayout.getOpIdx(), op.getSrc(), adaptor.getSrc(), blockedLayout,
-        thread, loc, getTypeConverter(), rewriter);
-    rewriter.replaceOp(op, res);
-    return success();
-  }
   LogicalResult
   lowerSharedToDistributed(LocalLoadOp op, LocalLoadOpAdaptor adaptor,
                            const LLVMTypeConverter *typeConverter,
```
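The net effect is that layout support is no longer decided by a hand-maintained allow-list in the conversion pattern; every destination encoding that can express a linear layout, now including FMA dot operands with a blocked parent, goes through the single `lowerSharedToDistributed` path. A toy sketch of that design move (names are simplified stand-ins, not the MLIR API):

```cpp
#include <cstdio>
#include <optional>

// Each encoding reports its own linear layout instead of the lowering
// keeping a static isSupportedLayout() allow-list.
struct Layout {
  bool hasLinearLayout;
  std::optional<int> toLinearLayout() const {
    return hasLinearLayout ? std::optional<int>(42) : std::nullopt;
  }
};

bool lowerLocalLoad(const Layout &dst) {
  auto ll = dst.toLinearLayout(); // single source of truth
  if (!ll)
    return false; // match failure; no special-cased FMA path anymore
  std::puts("lowered via linear layout");
  return true;
}

int main() { lowerLocalLoad(Layout{/*hasLinearLayout=*/true}); }
```
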

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 63 additions & 16 deletions
```diff
@@ -240,8 +240,11 @@ LinearLayout sharedToLinearLayoutLeadingOffset(ArrayRef<int64_t> shape,
   return combineCtaCgaWithShape(tileLayout, shared.getCTALayout(), shape);
 }

-LinearLayout warpsDotOperand(MLIRContext *ctx, ArrayRef<unsigned> warpShape,
-                             ArrayRef<unsigned> warpOrder, unsigned inner) {
+/// Function to generate lane and warp layout for dot operands.
+LinearLayout broadcastedDotOperandLayout(MLIRContext *ctx,
+                                         ArrayRef<unsigned> shape,
+                                         ArrayRef<unsigned> order,
+                                         unsigned kDim, StringAttr inDimName) {
   // Let warpsPerCTAMma = {2, 2}, then
   // warpsPerCTA = {2, 1} for opA and warpsPerCTA = {1, 2} for opB
   // assume warpOrder = {1, 0}
@@ -256,24 +259,23 @@ LinearLayout warpsDotOperand(MLIRContext *ctx, ArrayRef<unsigned> warpShape,
   // - - | - - - -  | - -
   // 2 3 | 2 3 0 2  | 1 3
   // In other words, we need to broadcast along K
-  auto rank = warpShape.size();
+  auto rank = shape.size();
   auto dimNames = standardOutDimNames(ctx, rank);
-  LinearLayout warpLayout = LinearLayout::empty();
+  LinearLayout layout = LinearLayout::empty();

   // We have to broadcast along the inner dimension
   // For A, when moving along M we go from 0 to 2.
   // For B, when moving along N we go from 0 to 1.
   // As such, choosing the order of A {1, 0}, gives us the correct broadcasting
   // Same happens if the warpOrder is {0, 1}, like in Hopper
-  for (auto d : warpOrder) {
-    if (d == inner) {
-      warpLayout *= LinearLayout::zeros1D(warpShape[d], S("warp"), dimNames[d]);
+  for (auto d : order) {
+    if (d == kDim) {
+      layout *= LinearLayout::zeros1D(shape[d], inDimName, dimNames[d]);
     } else {
-      warpLayout *=
-          LinearLayout::identity1D(warpShape[d], S("warp"), dimNames[d]);
+      layout *= LinearLayout::identity1D(shape[d], inDimName, dimNames[d]);
     }
   }
-  return warpLayout;
+  return layout;
 }

 } // anonymous namespace
```
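The renamed helper builds each dimension from one of two 1-D blocks: `identity1D` for non-K dimensions and `zeros1D` for K, so all lanes or warps along K address the same elements. A toy model of those blocks, assuming the usual linear-layout convention that outputs are the XOR-combination of the basis images of the set input bits:

```cpp
#include <cstdio>
#include <vector>

// A 1-D "linear layout" here is just the images of the input basis bits.
std::vector<unsigned> identity1D(unsigned size) {
  std::vector<unsigned> bases;
  for (unsigned b = 1; b < size; b <<= 1)
    bases.push_back(b); // input bit i -> output bit i
  return bases;
}

std::vector<unsigned> zeros1D(unsigned size) {
  std::vector<unsigned> bases;
  for (unsigned b = 1; b < size; b <<= 1)
    bases.push_back(0); // every input maps to output 0: a broadcast
  return bases;
}

unsigned apply(const std::vector<unsigned> &bases, unsigned in) {
  unsigned out = 0;
  for (unsigned i = 0; i < bases.size(); ++i)
    if (in & (1u << i))
      out ^= bases[i]; // XOR-combine the selected basis images
  return out;
}

int main() {
  auto id = identity1D(4), zs = zeros1D(4);
  for (unsigned w = 0; w < 4; ++w) // the broadcast column is all zeros
    std::printf("warp %u -> identity %u, broadcast %u\n", w, apply(id, w),
                apply(zs, w));
}
```
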
```diff
@@ -621,7 +623,8 @@ wmmaDotOperandToLinearLayout(DotOperandEncodingAttr dotWmmaLayout,
   // Generate warp layout
   auto warpsPerCTA = wmmaLayout.getWarpsPerCTA();
   auto warpOrder = triton::gpu::getWarpOrder(dotWmmaLayout);
-  LinearLayout warpLayout = warpsDotOperand(ctx, warpsPerCTA, warpOrder, kDim);
+  LinearLayout warpLayout =
+      broadcastedDotOperandLayout(ctx, warpsPerCTA, warpOrder, kDim, S("warp"));

   // reorder dim names in rep order, so combineCtaCgaWithShape generate proper
   // extension of layout
@@ -651,6 +654,48 @@ BlockedEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape);
 }

+std::optional<LinearLayout>
+fmaDotToLinearLayout(DotOperandEncodingAttr operandLayout,
+                     ArrayRef<int64_t> shape) {
+  int rank = shape.size();
+  auto blocked = cast<BlockedEncodingAttr>(operandLayout.getParent());
+  MLIRContext *ctx = operandLayout.getContext();
+
+  // TODO: introduce registerOrder or use getOrder(operandLayout)
+  // Currently this order is used in the legacy converter, because we do
+  // not have access to the full dot operand layout, only the parent part.
+  auto regOrder = blocked.getOrder();
+  // TODO: use operandLayout.getThreadOrder()
+  auto threadOrder = blocked.getThreadOrder();
+  auto warpOrder = blocked.getWarpOrder();
+  auto repOrder = blocked.getRepOrder();
+
+  StringAttr kReg = S("register");
+  StringAttr kLane = S("lane");
+  StringAttr kWarp = S("warp");
+
+  SmallVector<unsigned> threadSize = blocked.getSizePerThread();
+  auto kDimIdx = operandLayout.getOpIdx() == 0 ? rank - 1 : rank - 2;
+  threadSize[kDimIdx] = shape[kDimIdx];
+  auto threadShape = blocked.getThreadsPerWarp();
+  auto warpShape = blocked.getWarpsPerCTA();
+
+  SmallVector<StringAttr> repDimNames =
+      permuteDimNames(standardOutDimNames(ctx, rank), repOrder);
+
+  auto registersLayout = identityStandardND(kReg, threadSize, regOrder);
+  auto lanesLayout = broadcastedDotOperandLayout(ctx, threadShape, threadOrder,
+                                                 kDimIdx, kLane);
+  auto warpsLayout =
+      broadcastedDotOperandLayout(ctx, warpShape, warpOrder, kDimIdx, kWarp);
+
+  LinearLayout ctaLayout = registersLayout.transposeOuts(repDimNames) *
+                           lanesLayout.transposeOuts(repDimNames) *
+                           warpsLayout.transposeOuts(repDimNames);
+
+  return combineCtaCgaWithShape(ctaLayout, getCTALayout(operandLayout), shape);
+}
+
 LinearLayout nvidiaMmaTile(MLIRContext *ctx, ArrayRef<unsigned> tileShape,
                            unsigned kWidth, ArrayRef<unsigned> order,
                            ArrayRef<unsigned> repOrder) {
```
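Inside `fmaDotToLinearLayout`, the key move is widening the register dimension so each thread owns the full K extent of its operand: `threadSize[kDimIdx] = shape[kDimIdx]`, where K is the last dimension for A (`opIdx == 0`) and the second-to-last for B. A small sketch of just that shape computation; `operandSizePerThread` is an illustrative helper, not part of the patch:

```cpp
#include <cassert>
#include <vector>

// For an FMA dot operand, every thread holds the full K slice, so the
// blocked sizePerThread is widened to the tensor shape along K.
std::vector<unsigned> operandSizePerThread(std::vector<unsigned> sizePerThread,
                                           const std::vector<long> &shape,
                                           int opIdx) {
  int rank = shape.size();
  int kDimIdx = (opIdx == 0) ? rank - 1 : rank - 2;
  sizePerThread[kDimIdx] = static_cast<unsigned>(shape[kDimIdx]);
  return sizePerThread;
}

int main() {
  // 2-D A operand of a 64x32 (M x K) tile, blocked sizePerThread = {1, 4}:
  auto a = operandSizePerThread({1, 4}, {64, 32}, /*opIdx=*/0);
  assert(a[0] == 1 && a[1] == 32); // the thread keeps all K = 32 elements
  // 2-D B operand of a 32x64 (K x N) tile:
  auto b = operandSizePerThread({1, 4}, {32, 64}, /*opIdx=*/1);
  assert(b[0] == 32 && b[1] == 4);
}
```
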
```diff
@@ -741,19 +786,21 @@ LinearLayout nvidiaDotToLinearLayout(ArrayRef<int64_t> shape,
   auto ctaLayout =
       nvidiaMmaTile(ctx, tileShape, kWidth, getOrder(dot), dot.getRepOrder());
   auto kDim = isA ? rank - 1 : rank - 2;
-  ctaLayout *=
-      warpsDotOperand(ctx, mma.getWarpsPerCTA(), mma.getWarpOrder(), kDim)
-          .transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames()));
+  ctaLayout *= broadcastedDotOperandLayout(ctx, mma.getWarpsPerCTA(),
+                                           mma.getWarpOrder(), kDim, S("warp"))
+                   .transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames()));

   return combineCtaCgaWithShape(ctaLayout, getCTALayout(dot), shape);
 }

 std::optional<LinearLayout>
 DotOperandEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   auto parent = getParent();
-  if (auto mfmaLayout = llvm::dyn_cast<AMDMfmaEncodingAttr>(parent)) {
+  if (auto blockedLayout = mlir::dyn_cast<BlockedEncodingAttr>(parent)) {
+    return fmaDotToLinearLayout(*this, shape);
+  } else if (auto mfmaLayout = mlir::dyn_cast<AMDMfmaEncodingAttr>(parent)) {
     return mfmaDotToLinearLayout(*this, shape);
-  } else if (auto wmmaLayout = llvm::dyn_cast<AMDWmmaEncodingAttr>(parent)) {
+  } else if (auto wmmaLayout = mlir::dyn_cast<AMDWmmaEncodingAttr>(parent)) {
     return wmmaDotOperandToLinearLayout(*this, shape);
   } else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
     return nvidiaDotToLinearLayout(shape, *this);
```

test/Conversion/amd/decompose-unsupported-conversions.mlir

Lines changed: 2 additions & 2 deletions
```diff
@@ -97,11 +97,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
 #blocked = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx940", "ttg.threads-per-warp" = 64 : i32} {
-  tt.func @neg_blocked_to_dot_op_incompatible_warp_gfx940(%arg0: tensor<32x32xf16, #blocked>) {
+  tt.func @neg_blocked_to_dot_op_incompatible_warp_gfx940(%arg0: tensor<128x128xf16, #blocked>) {
     // CHECK-NOT: ttg.convert_layout
     // CHECK: ttg.local_alloc
     // CHECK: ttg.local_load
-    %0 = ttg.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>>
+    %0 = ttg.convert_layout %arg0 : tensor<128x128xf16, #blocked> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>>
     tt.return
   }
 }
```
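One plausible reading of the 32x32 → 128x128 bump (inferred, not stated in the commit): a blocked layout's CTA tile spans `sizePerThread * threadsPerWarp * warpsPerCTA` elements per dimension, so a 32x32 tensor fits inside a single tile of both layouts and presumably no longer exercises this negative path now that FMA dot operands have a linear layout, whereas 128x128 still spans multiple tiles. The tile arithmetic:

```cpp
#include <array>
#include <cstdio>

// CTA tile extent of a 2-D blocked layout, per dimension.
std::array<unsigned, 2> shapePerCTATile(std::array<unsigned, 2> sizePerThread,
                                        std::array<unsigned, 2> threadsPerWarp,
                                        std::array<unsigned, 2> warpsPerCTA) {
  return {sizePerThread[0] * threadsPerWarp[0] * warpsPerCTA[0],
          sizePerThread[1] * threadsPerWarp[1] * warpsPerCTA[1]};
}

int main() {
  auto src = shapePerCTATile({1, 32}, {32, 2}, {2, 2}); // #blocked:  64x128
  auto dst = shapePerCTATile({1, 32}, {32, 2}, {4, 1}); // #blocked1: 128x64
  std::printf("src tile %ux%u, dst tile %ux%u\n", src[0], src[1], dst[0],
              dst[1]); // 32x32 fits one tile of each; 128x128 does not
}
```
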
