
Commit a14e141

Update FMA code (#4613)
Just align code with triton upstream.

Signed-off-by: Anatoly Myachev <[email protected]>
1 parent: 248c431 · commit: a14e141

File tree

2 files changed: +51 −109 lines


third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM.cpp

Lines changed: 4 additions & 3 deletions
@@ -7,9 +7,10 @@ using ::mlir::triton::gpu::getShapePerCTA;
 using ::mlir::triton::gpu::intel::DpasEncodingAttr;

 namespace fma_details {
-LogicalResult convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor,
-                            TritonIntelGPUToLLVMTypeConverter *typeConverter,
-                            ConversionPatternRewriter &rewriter);
+LogicalResult
+convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor,
+              const TritonIntelGPUToLLVMTypeConverter *typeConverter,
+              ConversionPatternRewriter &rewriter);

 LogicalResult convertDPAS(triton::DotOp op, triton::DotOp::Adaptor adaptor,
                           TritonIntelGPUToLLVMTypeConverter *typeConverter,
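The header-side change is declaration-only: convertFMADot now takes the type converter as a pointer-to-const (and wraps the return type onto its own line), matching the updated definition in the FMA lowering file below; per the commit message, this mirrors the upstream triton signature.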
Lines changed: 47 additions & 106 deletions
@@ -1,119 +1,60 @@
-#include "../TritonGPUToLLVMBase.h"
+#include "TritonIntelGPUToLLVM/TypeConverter.h"
+#include "triton/Conversion/TritonGPUToLLVM/FMADotUtility.h"
+#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
 #include "llvm/ADT/TypeSwitch.h"

 using namespace mlir;
 using namespace mlir::triton;
 using namespace ::mlir::triton::gpu;

-using ::mlir::LLVM::linearize;
-using ::mlir::triton::gpu::expandMatrixOrderWithBatch;
-using ::mlir::triton::gpu::expandMatrixShapeWithBatch;
-using ::mlir::triton::gpu::getShapePerCTA;
-
-using ValueTableFMA = std::map<std::tuple<int, int, int>, Value>;
-
-static ValueTableFMA
-getValueTableFromStructFMA(Value val, ArrayRef<unsigned> perTileShape,
-                           unsigned kDim, unsigned nonKDim,
-                           ConversionPatternRewriter &rewriter, Location loc,
-                           ArrayRef<unsigned> order) {
-  ValueTableFMA res;
-  auto elems = unpackLLElements(loc, val, rewriter);
-  assert(perTileShape.size() == 3);
-  assert(elems.size() == product(perTileShape));
-  assert(kDim == 1 || kDim == 2);
-  assert(nonKDim == 1 || nonKDim == 2);
-  const unsigned bDim = 0;
-
-  for (unsigned idx = 0; idx < elems.size(); ++idx) {
-    auto spatialIdx = mlir::LLVM::delinearize(idx, perTileShape, order);
-    res[{spatialIdx[bDim], spatialIdx[nonKDim], spatialIdx[kDim]}] = elems[idx];
+namespace {
+class GenericFMAVectorMultiplier : public FMAVectorMultiplier {
+  OpBuilder &builder;
+  Location loc;
+
+public:
+  GenericFMAVectorMultiplier(OpBuilder &builder, Location loc)
+      : builder(builder), loc(loc) {}
+
+  Value multiplyVectors(ArrayRef<Value> a, ArrayRef<Value> b,
+                        Value c) override {
+    auto K = a.size();
+    assert(b.size() == K);
+    Value accum = c;
+    Type tgtTy = accum.getType();
+    for (auto it = llvm::zip(a, b).begin(); it != llvm::zip(a, b).end(); ++it) {
+      const auto &aElem = std::get<0>(*it);
+      const auto &bElem = std::get<1>(*it);
+
+      assert(aElem.getType() == tgtTy);
+      assert(bElem.getType() == tgtTy);
+
+      llvm::TypeSwitch<Type>(tgtTy)
+          .Case<FloatType>([&](auto) {
+            accum = builder.create<LLVM::FMulAddOp>(loc, aElem, bElem, accum);
+          })
+          .Case<IntegerType>([&](auto) {
+            accum = builder.create<LLVM::AddOp>(
+                loc, builder.create<LLVM::MulOp>(loc, aElem, bElem), accum);
+          });
+    }
+    return accum;
   }
-  return res;
-}
+};
+
+} // namespace

 namespace fma_details {
-LogicalResult convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor,
-                            TritonIntelGPUToLLVMTypeConverter *typeConverter,
-                            ConversionPatternRewriter &rewriter) {
+
+LogicalResult
+convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor,
+              const TritonIntelGPUToLLVMTypeConverter *typeConverter,
+              ConversionPatternRewriter &rewriter) {
   auto *ctx = rewriter.getContext();
   auto loc = op.getLoc();
-
-  auto A = op.getA();
-  auto D = op.getResult();
-
-  auto aTensorTy = cast<RankedTensorType>(A.getType());
-  auto dTensorTy = cast<RankedTensorType>(D.getType());
-
-  SmallVector<int64_t> aShapePerCTA =
-      expandMatrixShapeWithBatch(ArrayRef(getShapePerCTA(aTensorTy)));
-  auto dShapePerCTA =
-      expandMatrixShapeWithBatch(ArrayRef(getShapePerCTA(dTensorTy)));
-
-  BlockedEncodingAttr dLayout =
-      cast<BlockedEncodingAttr>(dTensorTy.getEncoding());
-  auto order = expandMatrixOrderWithBatch(dLayout.getOrder());
-  auto cc = unpackLLElements(loc, adaptor.getC(), rewriter);
-
-  Value llA = adaptor.getA();
-  Value llB = adaptor.getB();
-
-  auto sizePerThread = getContigPerThread(dTensorTy);
-  SmallVector<unsigned> shapePerCTATile;
-  for (auto [reg, thread, warp] :
-       llvm::zip(sizePerThread, dLayout.getThreadsPerWarp(),
-                 dLayout.getWarpsPerCTA())) {
-    shapePerCTATile.push_back(reg * thread * warp);
-  }
-  shapePerCTATile = expandMatrixShapeWithBatch(ArrayRef(shapePerCTATile));
-  sizePerThread = expandMatrixShapeWithBatch(ArrayRef(sizePerThread));
-
-  unsigned K = aShapePerCTA[2];
-
-  unsigned perThreadShape[3];
-  for (int i = 0; i < 3; ++i) {
-    unsigned numRep = dShapePerCTA[i] / shapePerCTATile[i];
-    numRep = std::max(static_cast<unsigned>(1), numRep);
-    perThreadShape[i] = numRep * sizePerThread[i];
-  }
-
-  auto has = getValueTableFromStructFMA(
-      llA, {perThreadShape[0], perThreadShape[1], K},
-      /*kDim*/ 2, /*nonKDim*/ 1, rewriter, loc, order);
-  auto hbs = getValueTableFromStructFMA(
-      llB, {perThreadShape[0], K, perThreadShape[2]},
-      /*kDim*/ 1, /*nonKDim*/ 2, rewriter, loc, order);
-
-  SmallVector<Value> acc = cc;
-
-  for (unsigned b = 0; b < perThreadShape[0]; ++b)
-    for (unsigned m = 0; m < perThreadShape[1]; ++m)
-      for (unsigned n = 0; n < perThreadShape[2]; ++n) {
-        SmallVector<unsigned> multiDimAccumIdx = {b, m, n};
-        unsigned linearAccumIdx =
-            linearize(multiDimAccumIdx, perThreadShape, order);
-        for (unsigned k = 0; k < K; ++k) {
-          Type tgtTy = acc[linearAccumIdx].getType();
-          Value opA = has[{b, m, k}];
-          Value opB = hbs[{b, n, k}];
-          assert(opA.getType() == tgtTy);
-          assert(opB.getType() == tgtTy);
-          llvm::TypeSwitch<Type>(tgtTy)
-              .Case<FloatType>([&](auto) {
-                acc[linearAccumIdx] = rewriter.create<LLVM::FMulAddOp>(
-                    loc, opA, opB, acc[linearAccumIdx]);
-              })
-              .Case<IntegerType>([&](auto) {
-                acc[linearAccumIdx] = rewriter.create<LLVM::AddOp>(
-                    loc, rewriter.create<LLVM::MulOp>(loc, opA, opB),
-                    acc[linearAccumIdx]);
-              });
-        }
-      }
-
-  auto res = packLLElements(loc, typeConverter, acc, rewriter, dTensorTy);
-  rewriter.replaceOp(op, res);
-
-  return success();
+  GenericFMAVectorMultiplier multiplier(rewriter, loc);
+  return parametricConvertFMADot(op, adaptor, typeConverter, rewriter,
+                                 multiplier);
 }
+
 } // namespace fma_details
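The net effect of this diff: the hand-rolled value-table bookkeeping and triple accumulation loop are deleted, and the Intel backend now only supplies the per-element multiply-accumulate (GenericFMAVectorMultiplier) while upstream's parametricConvertFMADot owns the loop structure, layouts, and packing. To make that split concrete, here is a minimal standalone analogue in plain C++, not the actual MLIR interfaces: VectorMultiplier, GenericMultiplier, and dotTile are illustrative names standing in for FMAVectorMultiplier, GenericFMAVectorMultiplier, and parametricConvertFMADot; std::fma over double stands in for LLVM::FMulAddOp over mlir::Value, and the real driver additionally handles the batch dimension, register layouts, and integer types.

#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

// Analogue of FMAVectorMultiplier: the hook a backend implements.
struct VectorMultiplier {
  virtual ~VectorMultiplier() = default;
  // Returns c + sum over k of a[k] * b[k].
  virtual double multiplyVectors(const std::vector<double> &a,
                                 const std::vector<double> &b, double c) = 0;
};

// Analogue of GenericFMAVectorMultiplier: chains fused multiply-adds
// along the K dimension.
struct GenericMultiplier : VectorMultiplier {
  double multiplyVectors(const std::vector<double> &a,
                         const std::vector<double> &b, double c) override {
    assert(a.size() == b.size());
    double accum = c;
    for (size_t k = 0; k < a.size(); ++k)
      accum = std::fma(a[k], b[k], accum); // accum += a[k] * b[k], fused
    return accum;
  }
};

// Analogue of parametricConvertFMADot: the shared driver iterates over
// output elements and delegates each K-slice reduction to the multiplier.
// Bt holds B transposed, so Bt[n] is column n of B.
void dotTile(const std::vector<std::vector<double>> &A,
             const std::vector<std::vector<double>> &Bt,
             std::vector<std::vector<double>> &C, VectorMultiplier &mul) {
  assert(C.size() == A.size());
  for (size_t m = 0; m < A.size(); ++m)
    for (size_t n = 0; n < Bt.size(); ++n)
      C[m][n] = mul.multiplyVectors(A[m], Bt[n], C[m][n]);
}

int main() {
  std::vector<std::vector<double>> A = {{1, 2}, {3, 4}};
  std::vector<std::vector<double>> Bt = {{5, 7}, {6, 8}}; // B = {{5,6},{7,8}}
  std::vector<std::vector<double>> C = {{0, 0}, {0, 0}};
  GenericMultiplier mul;
  dotTile(A, Bt, C, mul);
  std::printf("%g %g\n%g %g\n", C[0][0], C[0][1], C[1][0], C[1][1]);
  // Prints:
  // 19 22
  // 43 50
}

With this factoring, a backend that has a better multiply-accumulate primitive only overrides multiplyVectors; the loop over output elements is written once upstream, which is what lets this commit delete roughly a hundred lines of Intel-specific code.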
