Skip to content

Commit 0753712

Browse files
authored
[backend] NFC: Split architecture dependent and independent parts of FMA dot conversion (#5655)
This PR splits FMA dot conversion from Triton GPU to LLVM into two parts: - Common code with iteration across the M/N dims - Architecture dependent scalar multiplication of vectors across the K dim. This PR does not introduce any tests, because it does not fix any bugs or introduce new functionality; it just refactors code.
1 parent 9a49104 commit 0753712

File tree

4 files changed

+224
-129
lines changed

4 files changed

+224
-129
lines changed
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
#ifndef TRITON_CONVERSION_FMA_DOT_UTILITY_H
#define TRITON_CONVERSION_FMA_DOT_UTILITY_H

#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

namespace mlir::triton::gpu {

/// Abstract interface for scalar multiplication of Value vectors.
///
/// Enables generation of hardware-specific code in different backends:
/// each backend supplies its own implementation of the K-dimension
/// multiply-accumulate.
class FMAVectorMultiplier {
public:
  /// \returns the scalar product of two arrays, plus c: a·b + c
  virtual Value multiplyVectors(ArrayRef<Value> a, ArrayRef<Value> b,
                                Value c) = 0;

  virtual ~FMAVectorMultiplier() = default;
};

/// Implements a framework for FMA dot conversion to llvm.
///
/// This function implements the architecture-independent part of FMA dot
/// conversion (iteration across the M/N dimensions) and calls the
/// "multiplier" object, which is defined by the caller and implements the
/// architecture-dependent part of the conversion.
LogicalResult parametricConvertFMADot(DotOp op, DotOp::Adaptor adaptor,
                                      const LLVMTypeConverter *typeConverter,
                                      ConversionPatternRewriter &rewriter,
                                      FMAVectorMultiplier &multiplier);

} // namespace mlir::triton::gpu

#endif // TRITON_CONVERSION_FMA_DOT_UTILITY_H

lib/Conversion/TritonGPUToLLVM/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
add_triton_library(TritonGPUToLLVM
22
ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp
33
DotOpToLLVM/FMA.cpp
4+
DotOpToLLVM/FMADotUtility.cpp
45
AllocateSharedMemory.cpp
56
AssertOpToLLVM.cpp
67
ControlFlowOpToLLVM.cpp
Lines changed: 23 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -1,144 +1,38 @@
1-
#include "mlir/Support/LLVM.h"
1+
#include "triton/Conversion/TritonGPUToLLVM/FMADotUtility.h"
22
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
3-
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
4-
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
53

64
using namespace mlir;
75
using namespace mlir::triton;
86
using namespace ::mlir::triton::gpu;
97

10-
using ::mlir::LLVM::linearize;
11-
using ::mlir::triton::gpu::expandMatrixOrderWithBatch;
12-
using ::mlir::triton::gpu::expandMatrixShapeWithBatch;
13-
using ::mlir::triton::gpu::getShapePerCTA;
14-
using ::mlir::triton::gpu::getSizePerThread;
15-
16-
/// \brief spatial position of repetition and register of a given value
17-
struct OperandValueKey {
18-
unsigned bRepIdx, nonKRepIdx;
19-
unsigned bIdx, nonKIdx, kIdx;
20-
21-
bool operator==(const OperandValueKey &other) const {
22-
return (bRepIdx == other.bRepIdx && nonKRepIdx == other.nonKRepIdx &&
23-
bIdx == other.bIdx && nonKIdx == other.nonKIdx &&
24-
kIdx == other.kIdx);
25-
}
26-
};
27-
28-
template <> struct std::hash<OperandValueKey> {
29-
std::size_t operator()(const OperandValueKey &k) const {
30-
return llvm::hash_combine(k.bRepIdx, k.nonKRepIdx, k.bIdx, k.nonKIdx,
31-
k.kIdx);
8+
namespace {
9+
class GenericFMAVectorMultiplier : public FMAVectorMultiplier {
10+
OpBuilder &builder;
11+
Location loc;
12+
13+
public:
14+
GenericFMAVectorMultiplier(OpBuilder &builder, Location loc)
15+
: builder(builder), loc(loc) {}
16+
17+
Value multiplyVectors(ArrayRef<Value> a, ArrayRef<Value> b,
18+
Value c) override {
19+
auto K = a.size();
20+
assert(b.size() == K);
21+
Value accum = c;
22+
for (auto [aElem, bElem] : llvm::zip(a, b))
23+
accum = builder.create<LLVM::FMulAddOp>(loc, aElem, bElem, accum);
24+
return accum;
3225
}
3326
};
3427

35-
using ValueTableFMA = std::unordered_map<OperandValueKey, Value>;
36-
37-
static ValueTableFMA getValueTableFromStructFMA(
38-
Value val, ArrayRef<unsigned> perRepShape, ArrayRef<unsigned> repetitions,
39-
unsigned kDim, unsigned nonKDim, ConversionPatternRewriter &rewriter,
40-
Location loc, ArrayRef<unsigned> inRepOrder, ArrayRef<unsigned> repOrder) {
41-
ValueTableFMA res;
42-
auto elems = unpackLLElements(loc, val, rewriter);
43-
assert(perRepShape.size() == 3);
44-
auto numElemsRep = product(perRepShape);
45-
assert(elems.size() == numElemsRep * product(repetitions));
46-
assert(kDim == 1 || kDim == 2);
47-
assert(nonKDim == 1 || nonKDim == 2);
48-
const unsigned bDim = 0;
28+
} // namespace
4929

50-
for (unsigned idx = 0; idx < elems.size(); ++idx) {
51-
auto inRepLinearIdx = idx % numElemsRep;
52-
auto repLinearIdx = idx / numElemsRep;
53-
auto inRepSpatialIdx =
54-
mlir::LLVM::delinearize(inRepLinearIdx, perRepShape, inRepOrder);
55-
auto repSpatialIdx =
56-
mlir::LLVM::delinearize(repLinearIdx, repetitions, repOrder);
57-
OperandValueKey key{repSpatialIdx[0], repSpatialIdx[nonKDim],
58-
inRepSpatialIdx[0], inRepSpatialIdx[nonKDim],
59-
inRepSpatialIdx[kDim]};
60-
res[key] = elems[idx];
61-
}
62-
return res;
63-
}
64-
65-
/// Entry point for the generic (LLVM fmuladd based) FMA dot lowering.
///
/// Delegates all iteration logic to parametricConvertFMADot and supplies a
/// GenericFMAVectorMultiplier for the architecture-dependent K-dimension
/// multiply-accumulate.
LogicalResult convertFMADot(DotOp op, DotOp::Adaptor adaptor,
                            const LLVMTypeConverter *typeConverter,
                            ConversionPatternRewriter &rewriter) {
  // NOTE: the previous version also fetched rewriter.getContext() into an
  // unused local; it has been removed.
  auto loc = op.getLoc();
  GenericFMAVectorMultiplier multiplier(rewriter, loc);
  return parametricConvertFMADot(op, adaptor, typeConverter, rewriter,
                                 multiplier);
}
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
#include "triton/Conversion/TritonGPUToLLVM/FMADotUtility.h"
2+
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
3+
4+
using namespace mlir;
5+
6+
namespace {
7+
8+
/// OperandValueKey structure represents compile time part
9+
/// of spatial coordinates of a value in a tensor.
10+
///
11+
/// Every Value spatial coordinates(i.e. [batch;nonK;k]) in tensor can be
12+
/// defined as:
13+
///
14+
/// batch = (bRepIdx * CTABSize + bIdx) + (laneBCoord + warpBCoord)
15+
/// nonK = (nonKRepIdx * CTANKSize + nonKIdx) + (laneNonKCoord + warpNonKCoord)
16+
/// k = kIdx
17+
///
18+
/// Where:
19+
/// CTABSize, CTANKSize: constants;
20+
/// laneBCoord, warpBCoord, laneNonKCoord, warpNonKCoord: runtime components;
21+
/// bRepIdx, nonKRepIdx, bIdx, nonKIdx, kIdx: compile time components.
22+
struct OperandValueKey {
23+
unsigned bRepIdx, nonKRepIdx;
24+
unsigned bIdx, nonKIdx, kIdx;
25+
26+
bool operator==(const OperandValueKey &other) const {
27+
return (bRepIdx == other.bRepIdx && nonKRepIdx == other.nonKRepIdx &&
28+
bIdx == other.bIdx && nonKIdx == other.nonKIdx &&
29+
kIdx == other.kIdx);
30+
}
31+
};
32+
33+
} // namespace
34+
35+
template <> struct std::hash<OperandValueKey> {
36+
std::size_t operator()(const OperandValueKey &k) const {
37+
return llvm::hash_combine(k.bRepIdx, k.nonKRepIdx, k.bIdx, k.nonKIdx,
38+
k.kIdx);
39+
}
40+
};
41+
42+
namespace {
43+
44+
using ValueTableFMA = std::unordered_map<OperandValueKey, Value>;
45+
46+
ValueTableFMA getValueTableFromStructFMA(
47+
Value val, ArrayRef<unsigned> perRepShape, ArrayRef<unsigned> repetitions,
48+
unsigned kDim, unsigned nonKDim, ConversionPatternRewriter &rewriter,
49+
Location loc, ArrayRef<unsigned> inRepOrder, ArrayRef<unsigned> repOrder) {
50+
ValueTableFMA res;
51+
auto elems = unpackLLElements(loc, val, rewriter);
52+
assert(perRepShape.size() == 3);
53+
auto numElemsRep = product(perRepShape);
54+
assert(elems.size() == numElemsRep * product(repetitions));
55+
assert(kDim == 1 || kDim == 2);
56+
assert(nonKDim == 1 || nonKDim == 2);
57+
const unsigned bDim = 0;
58+
59+
for (unsigned idx = 0; idx < elems.size(); ++idx) {
60+
auto inRepLinearIdx = idx % numElemsRep;
61+
auto repLinearIdx = idx / numElemsRep;
62+
auto inRepSpatialIdx =
63+
mlir::LLVM::delinearize(inRepLinearIdx, perRepShape, inRepOrder);
64+
auto repSpatialIdx =
65+
mlir::LLVM::delinearize(repLinearIdx, repetitions, repOrder);
66+
OperandValueKey key{repSpatialIdx[0], repSpatialIdx[nonKDim],
67+
inRepSpatialIdx[0], inRepSpatialIdx[nonKDim],
68+
inRepSpatialIdx[kDim]};
69+
res[key] = elems[idx];
70+
}
71+
return res;
72+
}
73+
74+
} // namespace
75+
76+
namespace mlir::triton::gpu {
77+
78+
LogicalResult parametricConvertFMADot(DotOp op, DotOp::Adaptor adaptor,
79+
const LLVMTypeConverter *typeConverter,
80+
ConversionPatternRewriter &rewriter,
81+
FMAVectorMultiplier &multiplier) {
82+
auto *ctx = rewriter.getContext();
83+
auto loc = op.getLoc();
84+
85+
auto A = op.getA();
86+
auto D = op.getResult();
87+
88+
auto aTensorTy = cast<RankedTensorType>(A.getType());
89+
auto dTensorTy = cast<RankedTensorType>(D.getType());
90+
91+
SmallVector<int64_t> aShapePerCTA =
92+
expandMatrixShapeWithBatch(ArrayRef(getShapePerCTA(aTensorTy)));
93+
auto dShapePerCTA =
94+
expandMatrixShapeWithBatch(ArrayRef(getShapePerCTA(dTensorTy)));
95+
96+
BlockedEncodingAttr dLayout =
97+
cast<BlockedEncodingAttr>(dTensorTy.getEncoding());
98+
// TODO process A and B operand separately
99+
auto inRepOrder = expandMatrixOrderWithBatch(dLayout.getOrder());
100+
auto repOrder = expandMatrixOrderWithBatch(dLayout.getRepOrder());
101+
auto cc = unpackLLElements(loc, adaptor.getC(), rewriter);
102+
103+
Value llA = adaptor.getA();
104+
Value llB = adaptor.getB();
105+
106+
auto sizePerThread =
107+
expandMatrixShapeWithBatch(ArrayRef(getSizePerThread(dLayout)));
108+
auto numElemsPerThread = product(sizePerThread);
109+
auto shapePerCTATile =
110+
expandMatrixShapeWithBatch(ArrayRef(getShapePerCTATile(dLayout)));
111+
112+
unsigned K = aShapePerCTA[2];
113+
114+
unsigned threadTileShape[3];
115+
unsigned repetitions[3];
116+
for (int i = 0; i < 3; ++i) {
117+
repetitions[i] =
118+
ceil(dShapePerCTA[i], static_cast<int64_t>(shapePerCTATile[i]));
119+
}
120+
121+
auto has = getValueTableFromStructFMA(
122+
llA, {sizePerThread[0], sizePerThread[1], K},
123+
{repetitions[0], repetitions[1], 1},
124+
/*kDim*/ 2, /*nonKDim*/ 1, rewriter, loc, inRepOrder, repOrder);
125+
auto hbs = getValueTableFromStructFMA(
126+
llB, {sizePerThread[0], K, sizePerThread[2]},
127+
{repetitions[0], 1, repetitions[2]},
128+
/*kDim*/ 1, /*nonKDim*/ 2, rewriter, loc, inRepOrder, repOrder);
129+
130+
SmallVector<Value> acc = cc;
131+
132+
for (unsigned bRep = 0; bRep < repetitions[0]; ++bRep)
133+
for (unsigned mRep = 0; mRep < repetitions[1]; ++mRep)
134+
for (unsigned nRep = 0; nRep < repetitions[2]; ++nRep)
135+
for (unsigned b = 0; b < sizePerThread[0]; ++b)
136+
for (unsigned m = 0; m < sizePerThread[1]; ++m)
137+
for (unsigned n = 0; n < sizePerThread[2]; ++n) {
138+
SmallVector<unsigned> multiDimAccumIdx = {b, m, n};
139+
unsigned linearInRepIdx =
140+
LLVM::linearize(multiDimAccumIdx, sizePerThread, inRepOrder);
141+
SmallVector<unsigned> multiDimRepIdx = {bRep, mRep, nRep};
142+
unsigned linearRepIdx =
143+
LLVM::linearize(multiDimRepIdx, repetitions, repOrder);
144+
unsigned linearAccumIdx =
145+
linearInRepIdx + linearRepIdx * numElemsPerThread;
146+
147+
SmallVector<Value> aOpVector;
148+
SmallVector<Value> bOpVector;
149+
150+
for (unsigned k = 0; k < K; ++k) {
151+
aOpVector.push_back(has.at({bRep, mRep, b, m, k}));
152+
bOpVector.push_back(hbs.at({bRep, nRep, b, n, k}));
153+
}
154+
155+
acc[linearAccumIdx] = multiplier.multiplyVectors(
156+
aOpVector, bOpVector, acc[linearAccumIdx]);
157+
}
158+
159+
auto res = packLLElements(loc, typeConverter, acc, rewriter, dTensorTy);
160+
rewriter.replaceOp(op, res);
161+
162+
return success();
163+
}
164+
165+
} // namespace mlir::triton::gpu

0 commit comments

Comments
 (0)