Skip to content

Commit b6f6b41

Browse files
Revert "[Backend] Convert FMA dot operand to linear layout (#5469)"
This reverts commit d9facf3.
1 parent ae7a689 commit b6f6b41

File tree

6 files changed

+94
-225
lines changed

6 files changed

+94
-225
lines changed

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -98,13 +98,13 @@ void storeValuesInLinearVector(PatternRewriter &rewriter, Location loc,
9898
}
9999
}
100100

101-
bool verifyCTALayout(CTALayoutAttr ctaLayout) {
101+
void verifyCTALayout(CTALayoutAttr ctaLayout) {
102102
auto ctaSplit = ctaLayout.getCTASplitNum();
103103
for (auto split : ctaSplit) {
104104
if (split != 1)
105-
return false;
105+
llvm::report_fatal_error("tensors splited in CGA(thread group clusters) "
106+
"are not supported in FMA dot yet.");
106107
}
107-
return true;
108108
}
109109

110110
/// Get a linear offset of first element loaded by thread.
@@ -216,8 +216,7 @@ Value loadFMAOp(Value srcVal, Value llVal, BlockedEncodingAttr dLayout,
216216
Value thread, Location loc,
217217
const LLVMTypeConverter *typeConverter,
218218
ConversionPatternRewriter &rewriter, const int dotOpNo) {
219-
if (!verifyCTALayout(dLayout.getCTALayout()))
220-
return Value();
219+
verifyCTALayout(dLayout.getCTALayout());
221220

222221
DimIdx dim;
223222
dim.batch = 0;
@@ -293,15 +292,6 @@ Value loadFMAOp(Value srcVal, Value llVal, BlockedEncodingAttr dLayout,
293292
auto numBTiles = std::max(1u, B / shapePerCTABTile);
294293
auto numNonKTiles = std::max(1u, NonK / shapePerCTANonKTile);
295294

296-
// Found discrepancy in this case,
297-
// use linear layout based converter for this case
298-
// TODO: break batch and non-k dimension iterations in
299-
// "repeat" and "inside-repeate" parts, pack them in llvm structure
300-
// according repeat and register order.
301-
// See FMA.cpp:getValueTableFromStructFMA for reference
302-
if (numBTiles != 1 || numNonKTiles != 1)
303-
return Value();
304-
305295
auto perThreadShape =
306296
getElemsPerThreadInOp(opTensorShape, shapePerCTATile, sizePerThread);
307297

lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp

Lines changed: 30 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -13,51 +13,24 @@ using ::mlir::triton::gpu::expandMatrixShapeWithBatch;
1313
using ::mlir::triton::gpu::getShapePerCTA;
1414
using ::mlir::triton::gpu::getSizePerThread;
1515

16-
/// \brief spatial position of repetition and register of a given value
17-
struct OperandValueKey {
18-
unsigned bRepIdx, nonKRepIdx;
19-
unsigned bIdx, nonKIdx, kIdx;
20-
21-
bool operator==(const OperandValueKey &other) const {
22-
return (bRepIdx == other.bRepIdx && nonKRepIdx == other.nonKRepIdx &&
23-
bIdx == other.bIdx && nonKIdx == other.nonKIdx &&
24-
kIdx == other.kIdx);
25-
}
26-
};
27-
28-
template <> struct std::hash<OperandValueKey> {
29-
std::size_t operator()(const OperandValueKey &k) const {
30-
return llvm::hash_combine(k.bRepIdx, k.nonKRepIdx, k.bIdx, k.nonKIdx,
31-
k.kIdx);
32-
}
33-
};
34-
35-
using ValueTableFMA = std::unordered_map<OperandValueKey, Value>;
16+
using ValueTableFMA = std::map<std::tuple<int, int, int>, Value>;
3617

37-
static ValueTableFMA getValueTableFromStructFMA(
38-
Value val, ArrayRef<unsigned> perRepShape, ArrayRef<unsigned> repetitions,
39-
unsigned kDim, unsigned nonKDim, ConversionPatternRewriter &rewriter,
40-
Location loc, ArrayRef<unsigned> inRepOrder, ArrayRef<unsigned> repOrder) {
18+
static ValueTableFMA
19+
getValueTableFromStructFMA(Value val, ArrayRef<unsigned> perTileShape,
20+
unsigned kDim, unsigned nonKDim,
21+
ConversionPatternRewriter &rewriter, Location loc,
22+
ArrayRef<unsigned> order) {
4123
ValueTableFMA res;
4224
auto elems = unpackLLElements(loc, val, rewriter);
43-
assert(perRepShape.size() == 3);
44-
auto numElemsRep = product(perRepShape);
45-
assert(elems.size() == numElemsRep * product(repetitions));
25+
assert(perTileShape.size() == 3);
26+
assert(elems.size() == product(perTileShape));
4627
assert(kDim == 1 || kDim == 2);
4728
assert(nonKDim == 1 || nonKDim == 2);
4829
const unsigned bDim = 0;
4930

5031
for (unsigned idx = 0; idx < elems.size(); ++idx) {
51-
auto inRepLinearIdx = idx % numElemsRep;
52-
auto repLinearIdx = idx / numElemsRep;
53-
auto inRepSpatialIdx =
54-
mlir::LLVM::delinearize(inRepLinearIdx, perRepShape, inRepOrder);
55-
auto repSpatialIdx =
56-
mlir::LLVM::delinearize(repLinearIdx, repetitions, repOrder);
57-
OperandValueKey key{repSpatialIdx[0], repSpatialIdx[nonKDim],
58-
inRepSpatialIdx[0], inRepSpatialIdx[nonKDim],
59-
inRepSpatialIdx[kDim]};
60-
res[key] = elems[idx];
32+
auto spatialIdx = mlir::LLVM::delinearize(idx, perTileShape, order);
33+
res[{spatialIdx[bDim], spatialIdx[nonKDim], spatialIdx[kDim]}] = elems[idx];
6134
}
6235
return res;
6336
}
@@ -81,61 +54,46 @@ LogicalResult convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor,
8154

8255
BlockedEncodingAttr dLayout =
8356
cast<BlockedEncodingAttr>(dTensorTy.getEncoding());
84-
// TODO process A and B operand separately
85-
auto inRepOrder = expandMatrixOrderWithBatch(dLayout.getOrder());
86-
auto repOrder = expandMatrixOrderWithBatch(dLayout.getRepOrder());
57+
auto order = expandMatrixOrderWithBatch(dLayout.getOrder());
8758
auto cc = unpackLLElements(loc, adaptor.getC(), rewriter);
8859

8960
Value llA = adaptor.getA();
9061
Value llB = adaptor.getB();
9162

9263
auto sizePerThread =
9364
expandMatrixShapeWithBatch(ArrayRef(getSizePerThread(dLayout)));
94-
auto numElemsPerThread = product(sizePerThread);
9565
auto shapePerCTATile =
9666
expandMatrixShapeWithBatch(ArrayRef(getShapePerCTATile(dLayout)));
9767

9868
unsigned K = aShapePerCTA[2];
9969

100-
unsigned threadTileShape[3];
101-
unsigned repetitions[3];
70+
unsigned perThreadShape[3];
10271
for (int i = 0; i < 3; ++i) {
103-
repetitions[i] =
104-
ceil(dShapePerCTA[i], static_cast<int64_t>(shapePerCTATile[i]));
72+
unsigned numRep = dShapePerCTA[i] / shapePerCTATile[i];
73+
numRep = std::max(static_cast<unsigned>(1), numRep);
74+
perThreadShape[i] = numRep * sizePerThread[i];
10575
}
10676

10777
auto has = getValueTableFromStructFMA(
108-
llA, {sizePerThread[0], sizePerThread[1], K},
109-
{repetitions[0], repetitions[1], 1},
110-
/*kDim*/ 2, /*nonKDim*/ 1, rewriter, loc, inRepOrder, repOrder);
78+
llA, {perThreadShape[0], perThreadShape[1], K},
79+
/*kDim*/ 2, /*nonKDim*/ 1, rewriter, loc, order);
11180
auto hbs = getValueTableFromStructFMA(
112-
llB, {sizePerThread[0], K, sizePerThread[2]},
113-
{repetitions[0], 1, repetitions[2]},
114-
/*kDim*/ 1, /*nonKDim*/ 2, rewriter, loc, inRepOrder, repOrder);
81+
llB, {perThreadShape[0], K, perThreadShape[2]},
82+
/*kDim*/ 1, /*nonKDim*/ 2, rewriter, loc, order);
11583

11684
SmallVector<Value> acc = cc;
11785

118-
for (unsigned bRep = 0; bRep < repetitions[0]; ++bRep)
119-
for (unsigned mRep = 0; mRep < repetitions[1]; ++mRep)
120-
for (unsigned nRep = 0; nRep < repetitions[2]; ++nRep)
121-
for (unsigned b = 0; b < sizePerThread[0]; ++b)
122-
for (unsigned m = 0; m < sizePerThread[1]; ++m)
123-
for (unsigned n = 0; n < sizePerThread[2]; ++n) {
124-
SmallVector<unsigned> multiDimAccumIdx = {b, m, n};
125-
unsigned linearInRepIdx =
126-
linearize(multiDimAccumIdx, sizePerThread, inRepOrder);
127-
SmallVector<unsigned> multiDimRepIdx = {bRep, mRep, nRep};
128-
unsigned linearRepIdx =
129-
linearize(multiDimRepIdx, repetitions, repOrder);
130-
unsigned linearAccumIdx =
131-
linearInRepIdx + linearRepIdx * numElemsPerThread;
132-
for (unsigned k = 0; k < K; ++k) {
133-
auto aOp = has[{bRep, mRep, b, m, k}];
134-
auto bOp = hbs[{bRep, nRep, b, n, k}];
135-
acc[linearAccumIdx] = rewriter.create<LLVM::FMulAddOp>(
136-
loc, aOp, bOp, acc[linearAccumIdx]);
137-
}
138-
}
86+
for (unsigned b = 0; b < perThreadShape[0]; ++b)
87+
for (unsigned m = 0; m < perThreadShape[1]; ++m)
88+
for (unsigned n = 0; n < perThreadShape[2]; ++n) {
89+
SmallVector<unsigned> multiDimAccumIdx = {b, m, n};
90+
unsigned linearAccumIdx =
91+
linearize(multiDimAccumIdx, perThreadShape, order);
92+
for (unsigned k = 0; k < K; ++k) {
93+
acc[linearAccumIdx] = rewriter.create<LLVM::FMulAddOp>(
94+
loc, has[{b, m, k}], hbs[{b, n, k}], acc[linearAccumIdx]);
95+
}
96+
}
13997

14098
auto res = packLLElements(loc, typeConverter, acc, rewriter, dTensorTy);
14199
rewriter.replaceOp(op, res);

lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,13 +119,54 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
119119
: ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) {
120120
}
121121

122+
// FIXME [Dot LL]
123+
// Do for all DotOperandEncodingAttr once we have LLs for all of them
124+
static bool isSupportedLayout(Attribute dstLayout) {
125+
if (isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr,
126+
LinearEncodingAttr>(dstLayout))
127+
return true;
128+
if (auto dot = dyn_cast<DotOperandEncodingAttr>(dstLayout)) {
129+
if (isa<MmaEncodingTrait>(dot.getParent()))
130+
return true;
131+
}
132+
return false;
133+
};
134+
122135
LogicalResult
123136
matchAndRewrite(LocalLoadOp op, OpAdaptor adaptor,
124137
ConversionPatternRewriter &rewriter) const override {
125-
return lowerSharedToDistributed(op, adaptor, getTypeConverter(), rewriter);
138+
RankedTensorType dstTy = op.getType();
139+
Attribute dstLayout = dstTy.getEncoding();
140+
if (isSupportedLayout(dstLayout)) {
141+
return lowerSharedToDistributed(op, adaptor, getTypeConverter(),
142+
rewriter);
143+
}
144+
if (isa<DotOperandEncodingAttr>(dstLayout) &&
145+
isa<BlockedEncodingAttr>(
146+
cast<DotOperandEncodingAttr>(dstLayout).getParent())) {
147+
return lowerSharedToDotOpFMA(op, adaptor, getTypeConverter(), rewriter);
148+
}
149+
return failure();
126150
}
127151

128152
private:
153+
LogicalResult
154+
lowerSharedToDotOpFMA(LocalLoadOp op, LocalLoadOpAdaptor adaptor,
155+
const LLVMTypeConverter *typeConverter,
156+
ConversionPatternRewriter &rewriter) const {
157+
auto loc = op.getLoc();
158+
RankedTensorType dstTy = op.getType();
159+
Attribute dstLayout = dstTy.getEncoding();
160+
auto dotLayout = cast<DotOperandEncodingAttr>(dstLayout);
161+
auto blockedLayout = cast<BlockedEncodingAttr>(
162+
cast<DotOperandEncodingAttr>(dstLayout).getParent());
163+
auto thread = getThreadId(rewriter, loc);
164+
Value res = SharedToDotOperandFMA::convertLayout(
165+
dotLayout.getOpIdx(), op.getSrc(), adaptor.getSrc(), blockedLayout,
166+
thread, loc, getTypeConverter(), rewriter);
167+
rewriter.replaceOp(op, res);
168+
return success();
169+
}
129170
LogicalResult
130171
lowerSharedToDistributed(LocalLoadOp op, LocalLoadOpAdaptor adaptor,
131172
const LLVMTypeConverter *typeConverter,

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 16 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -240,11 +240,8 @@ LinearLayout sharedToLinearLayoutLeadingOffset(ArrayRef<int64_t> shape,
240240
return combineCtaCgaWithShape(tileLayout, shared.getCTALayout(), shape);
241241
}
242242

243-
/// Function to generate lane and warp layout for dot operands.
244-
LinearLayout broadcastedDotOperandLayout(MLIRContext *ctx,
245-
ArrayRef<unsigned> shape,
246-
ArrayRef<unsigned> order,
247-
unsigned kDim, StringAttr inDimName) {
243+
LinearLayout warpsDotOperand(MLIRContext *ctx, ArrayRef<unsigned> warpShape,
244+
ArrayRef<unsigned> warpOrder, unsigned inner) {
248245
// Let warpsPerCTAMma = {2, 2}, then
249246
// warpsPerCTA = {2, 1} for opA and warpsPerCTA = {1, 2} for opB
250247
// assume warpOrder = {1, 0}
@@ -259,23 +256,24 @@ LinearLayout broadcastedDotOperandLayout(MLIRContext *ctx,
259256
// - - | - - - - | - -
260257
// 2 3 | 2 3 0 2 | 1 3
261258
// In other words, we need to broadcast along K
262-
auto rank = shape.size();
259+
auto rank = warpShape.size();
263260
auto dimNames = standardOutDimNames(ctx, rank);
264-
LinearLayout layout = LinearLayout::empty();
261+
LinearLayout warpLayout = LinearLayout::empty();
265262

266263
// We have to broadcast along the inner dimension
267264
// For A, when moving along M we go from 0 to 2.
268265
// For B, when moving along N we go from 0 to 1.
269266
// As such, choosing the order of A {1, 0}, gives us the correct broadcasting
270267
// Same happens if the warpOrder is {0, 1}, like in Hopper
271-
for (auto d : order) {
272-
if (d == kDim) {
273-
layout *= LinearLayout::zeros1D(shape[d], inDimName, dimNames[d]);
268+
for (auto d : warpOrder) {
269+
if (d == inner) {
270+
warpLayout *= LinearLayout::zeros1D(warpShape[d], S("warp"), dimNames[d]);
274271
} else {
275-
layout *= LinearLayout::identity1D(shape[d], inDimName, dimNames[d]);
272+
warpLayout *=
273+
LinearLayout::identity1D(warpShape[d], S("warp"), dimNames[d]);
276274
}
277275
}
278-
return layout;
276+
return warpLayout;
279277
}
280278

281279
} // anonymous namespace
@@ -623,8 +621,7 @@ wmmaDotOperandToLinearLayout(DotOperandEncodingAttr dotWmmaLayout,
623621
// Generate warp layout
624622
auto warpsPerCTA = wmmaLayout.getWarpsPerCTA();
625623
auto warpOrder = triton::gpu::getWarpOrder(dotWmmaLayout);
626-
LinearLayout warpLayout =
627-
broadcastedDotOperandLayout(ctx, warpsPerCTA, warpOrder, kDim, S("warp"));
624+
LinearLayout warpLayout = warpsDotOperand(ctx, warpsPerCTA, warpOrder, kDim);
628625

629626
// reorder dim names in rep order, so combineCtaCgaWithShape generate proper
630627
// extension of layout
@@ -654,48 +651,6 @@ BlockedEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
654651
return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape);
655652
}
656653

657-
std::optional<LinearLayout>
658-
fmaDotToLinearLayout(DotOperandEncodingAttr operandLayout,
659-
ArrayRef<int64_t> shape) {
660-
int rank = shape.size();
661-
auto blocked = cast<BlockedEncodingAttr>(operandLayout.getParent());
662-
MLIRContext *ctx = operandLayout.getContext();
663-
664-
// TODO: introduce registerOrder or use getOrder(operandLayout)
665-
// Currently this order is used in legacy converter, because we do not
666-
// have access to full dot operand layout, only parent part.
667-
auto regOrder = blocked.getOrder();
668-
// TODO: use operandLayout.getThreadOrder()
669-
auto threadOrder = blocked.getThreadOrder();
670-
auto warpOrder = blocked.getWarpOrder();
671-
auto repOrder = blocked.getRepOrder();
672-
673-
StringAttr kReg = S("register");
674-
StringAttr kLane = S("lane");
675-
StringAttr kWarp = S("warp");
676-
677-
SmallVector<unsigned> threadSize = blocked.getSizePerThread();
678-
auto kDimIdx = operandLayout.getOpIdx() == 0 ? rank - 1 : rank - 2;
679-
threadSize[kDimIdx] = shape[kDimIdx];
680-
auto threadShape = blocked.getThreadsPerWarp();
681-
auto warpShape = blocked.getWarpsPerCTA();
682-
683-
SmallVector<StringAttr> repDimNames =
684-
permuteDimNames(standardOutDimNames(ctx, rank), repOrder);
685-
686-
auto registersLayout = identityStandardND(kReg, threadSize, regOrder);
687-
auto lanesLayout = broadcastedDotOperandLayout(ctx, threadShape, threadOrder,
688-
kDimIdx, kLane);
689-
auto warpsLayout =
690-
broadcastedDotOperandLayout(ctx, warpShape, warpOrder, kDimIdx, kWarp);
691-
692-
LinearLayout ctaLayout = registersLayout.transposeOuts(repDimNames) *
693-
lanesLayout.transposeOuts(repDimNames) *
694-
warpsLayout.transposeOuts(repDimNames);
695-
696-
return combineCtaCgaWithShape(ctaLayout, getCTALayout(operandLayout), shape);
697-
}
698-
699654
LinearLayout nvidiaMmaTile(MLIRContext *ctx, ArrayRef<unsigned> tileShape,
700655
unsigned kWidth, ArrayRef<unsigned> order,
701656
ArrayRef<unsigned> repOrder) {
@@ -786,21 +741,19 @@ LinearLayout nvidiaDotToLinearLayout(ArrayRef<int64_t> shape,
786741
auto ctaLayout =
787742
nvidiaMmaTile(ctx, tileShape, kWidth, getOrder(dot), dot.getRepOrder());
788743
auto kDim = isA ? rank - 1 : rank - 2;
789-
ctaLayout *= broadcastedDotOperandLayout(ctx, mma.getWarpsPerCTA(),
790-
mma.getWarpOrder(), kDim, S("warp"))
791-
.transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames()));
744+
ctaLayout *=
745+
warpsDotOperand(ctx, mma.getWarpsPerCTA(), mma.getWarpOrder(), kDim)
746+
.transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames()));
792747

793748
return combineCtaCgaWithShape(ctaLayout, getCTALayout(dot), shape);
794749
}
795750

796751
std::optional<LinearLayout>
797752
DotOperandEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
798753
auto parent = getParent();
799-
if (auto blockedLayout = mlir::dyn_cast<BlockedEncodingAttr>(parent)) {
800-
return fmaDotToLinearLayout(*this, shape);
801-
} else if (auto mfmaLayout = mlir::dyn_cast<AMDMfmaEncodingAttr>(parent)) {
754+
if (auto mfmaLayout = llvm::dyn_cast<AMDMfmaEncodingAttr>(parent)) {
802755
return mfmaDotToLinearLayout(*this, shape);
803-
} else if (auto wmmaLayout = mlir::dyn_cast<AMDWmmaEncodingAttr>(parent)) {
756+
} else if (auto wmmaLayout = llvm::dyn_cast<AMDWmmaEncodingAttr>(parent)) {
804757
return wmmaDotOperandToLinearLayout(*this, shape);
805758
} else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
806759
return nvidiaDotToLinearLayout(shape, *this);

test/Conversion/amd/decompose-unsupported-conversions.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,11 +97,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
9797
#blocked = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}>
9898
#blocked1 = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
9999
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx940", "ttg.threads-per-warp" = 64 : i32} {
100-
tt.func @neg_blocked_to_dot_op_incompatible_warp_gfx940(%arg0: tensor<128x128xf16, #blocked>) {
100+
tt.func @neg_blocked_to_dot_op_incompatible_warp_gfx940(%arg0: tensor<32x32xf16, #blocked>) {
101101
// CHECK-NOT: ttg.convert_layout
102102
// CHECK: ttg.local_alloc
103103
// CHECK: ttg.local_load
104-
%0 = ttg.convert_layout %arg0 : tensor<128x128xf16, #blocked> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>>
104+
%0 = ttg.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>>
105105
tt.return
106106
}
107107
}

0 commit comments

Comments
 (0)