Skip to content

Commit daae522

Browse files
authored
[NFI]: Prepare code for FMA loop generation feature (#5160)
This PR prepares the codebase for FMA loop generation by refactoring dot operation conversion code. The main purpose is to extract and organize FMA-related functionality into separate modules for better code organization and reusability. --------- Signed-off-by: Ettore Tiotto <[email protected]>
1 parent 5bd23c9 commit daae522

File tree

5 files changed

+254
-7
lines changed

5 files changed

+254
-7
lines changed

third_party/intel/lib/TritonIntelGPUToLLVM/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ add_triton_library(TritonIntelGPUToLLVM
55
ControlFlowOpToLLVM.cpp
66
ConvertLayoutOpToLLVM.cpp
77
DotOpToLLVM/DPAS.cpp
8+
DotOpToLLVM/FMA.cpp
9+
DotOpToLLVM/FMADotUtility.cpp
810
DotOpToLLVM.cpp
911
ElementwiseOpToLLVM.cpp
1012
Fp4ToFpOpToLLVM.cpp

third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM.cpp

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,17 @@ using namespace mlir::triton;
66
using ::mlir::triton::gpu::getShapePerCTA;
77
using ::mlir::triton::gpu::intel::DpasEncodingAttr;
88

9-
namespace fma_details {
9+
namespace mlir::triton::gpu::intel {
10+
11+
LogicalResult convertFMADot(DotOp op, DotOp::Adaptor adaptor,
12+
const LLVMTypeConverter *typeConverter,
13+
ConversionPatternRewriter &rewriter);
14+
1015
LogicalResult convertDPAS(triton::DotOp op, triton::DotOp::Adaptor adaptor,
1116
TritonIntelGPUToLLVMTypeConverter *typeConverter,
1217
ConversionPatternRewriter &rewriter);
13-
} // namespace fma_details
18+
19+
} // namespace mlir::triton::gpu::intel
1420

1521
namespace {
1622
struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
@@ -33,13 +39,14 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
3339

3440
if (!isOuter && isa<DpasEncodingAttr>(
3541
cast<RankedTensorType>(D.getType()).getEncoding())) {
36-
return fma_details::convertDPAS(op, adaptor, getTypeConverter(),
37-
rewriter);
42+
return triton::gpu::intel::convertDPAS(op, adaptor, getTypeConverter(),
43+
rewriter);
3844
}
3945

4046
if (isa<BlockedEncodingAttr>(
4147
cast<RankedTensorType>(D.getType()).getEncoding()))
42-
return convertFMADot(op, adaptor, getTypeConverter(), rewriter);
48+
return triton::gpu::intel::convertFMADot(op, adaptor, getTypeConverter(),
49+
rewriter);
4350

4451
llvm::report_fatal_error(
4552
"Unsupported DotOp found when converting TritonGPU to LLVM.");

third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/DPAS.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,8 @@ class DotOpDPASConversionHelper {
406406

407407
} // namespace
408408

409-
namespace fma_details {
409+
namespace mlir::triton::gpu::intel {
410+
410411
LogicalResult convertDPAS(triton::DotOp op, triton::DotOp::Adaptor adaptor,
411412
TritonIntelGPUToLLVMTypeConverter *typeConverter,
412413
ConversionPatternRewriter &rewriter) {
@@ -441,4 +442,5 @@ LogicalResult convertDPAS(triton::DotOp op, triton::DotOp::Adaptor adaptor,
441442

442443
return helper.convertDot(op, adaptor);
443444
}
444-
} // namespace fma_details
445+
446+
} // namespace mlir::triton::gpu::intel
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
#include "triton/Conversion/TritonGPUToLLVM/FMADotUtility.h"
2+
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
3+
#include "llvm/ADT/TypeSwitch.h"
4+
5+
using namespace mlir;
6+
using namespace mlir::triton;
7+
using namespace ::mlir::triton::gpu;
8+
9+
namespace {
10+
class GenericFMAVectorMultiplier : public FMAVectorMultiplier {
11+
OpBuilder &builder;
12+
Location loc;
13+
14+
public:
15+
GenericFMAVectorMultiplier(OpBuilder &builder, Location loc)
16+
: builder(builder), loc(loc) {}
17+
18+
Value multiplyVectors(ArrayRef<Value> a, ArrayRef<Value> b,
19+
Value c) override {
20+
auto K = a.size();
21+
assert(b.size() == K);
22+
Value accum = c;
23+
Type tgtTy = accum.getType();
24+
for (auto it = llvm::zip(a, b).begin(); it != llvm::zip(a, b).end(); ++it) {
25+
const auto &aElem = std::get<0>(*it);
26+
const auto &bElem = std::get<1>(*it);
27+
28+
assert(aElem.getType() == tgtTy);
29+
assert(bElem.getType() == tgtTy);
30+
31+
// to avoid: 'llvm.intr.fmuladd' op operand #0 must be floating point LLVM
32+
// type or LLVM dialect-compatible vector of floating point LLVM type, but
33+
// got 'i32'
34+
llvm::TypeSwitch<Type>(tgtTy)
35+
.Case<FloatType>([&](auto) {
36+
accum = builder.create<LLVM::FMulAddOp>(loc, aElem, bElem, accum);
37+
})
38+
.Case<IntegerType>([&](auto) {
39+
accum = builder.create<LLVM::AddOp>(
40+
loc, builder.create<LLVM::MulOp>(loc, aElem, bElem), accum);
41+
});
42+
}
43+
return accum;
44+
}
45+
};
46+
47+
} // namespace
48+
49+
namespace mlir::triton::gpu::intel {
50+
51+
LogicalResult parametricConvertFMADot(DotOp op, DotOp::Adaptor adaptor,
52+
const LLVMTypeConverter *typeConverter,
53+
ConversionPatternRewriter &rewriter,
54+
FMAVectorMultiplier &multiplier);
55+
56+
LogicalResult convertFMADot(DotOp op, DotOp::Adaptor adaptor,
57+
const LLVMTypeConverter *typeConverter,
58+
ConversionPatternRewriter &rewriter) {
59+
auto *ctx = rewriter.getContext();
60+
auto loc = op.getLoc();
61+
GenericFMAVectorMultiplier multiplier(rewriter, loc);
62+
return intel::parametricConvertFMADot(op, adaptor, typeConverter, rewriter,
63+
multiplier);
64+
}
65+
66+
} // namespace mlir::triton::gpu::intel
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
#include "triton/Conversion/TritonGPUToLLVM/FMADotUtility.h"
2+
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
3+
4+
using namespace mlir;
5+
6+
namespace {
7+
8+
/// OperandValueKey structure represents compile time part
9+
/// of spatial coordinates of a value in a tensor.
10+
///
11+
/// Every Value spatial coordinates(i.e. [batch;nonK;k]) in tensor can be
12+
/// defined as:
13+
///
14+
/// batch = (bRepIdx * CTABSize + bIdx) + (laneBCoord + warpBCoord)
15+
/// nonK = (nonKRepIdx * CTANKSize + nonKIdx) + (laneNonKCoord + warpNonKCoord)
16+
/// k = kIdx
17+
///
18+
/// Where:
19+
/// CTABSize, CTANKSize: constants;
20+
/// laneBCoord, warpBCoord, laneNonKCoord, warpNonKCoord: runtime components;
21+
/// bRepIdx, nonKRepIdx, bIdx, nonKIdx, kIdx: compile time components.
22+
struct OperandValueKey {
23+
unsigned bRepIdx, nonKRepIdx;
24+
unsigned bIdx, nonKIdx, kIdx;
25+
26+
bool operator==(const OperandValueKey &other) const {
27+
return (bRepIdx == other.bRepIdx && nonKRepIdx == other.nonKRepIdx &&
28+
bIdx == other.bIdx && nonKIdx == other.nonKIdx &&
29+
kIdx == other.kIdx);
30+
}
31+
};
32+
33+
} // namespace
34+
35+
template <> struct std::hash<OperandValueKey> {
36+
std::size_t operator()(const OperandValueKey &k) const {
37+
return llvm::hash_combine(k.bRepIdx, k.nonKRepIdx, k.bIdx, k.nonKIdx,
38+
k.kIdx);
39+
}
40+
};
41+
42+
namespace {
43+
44+
using ValueTableFMA = std::unordered_map<OperandValueKey, Value>;
45+
46+
ValueTableFMA getValueTableFromStructFMA(
47+
Value val, ArrayRef<unsigned> perRepShape, ArrayRef<unsigned> repetitions,
48+
unsigned kDim, unsigned nonKDim, ConversionPatternRewriter &rewriter,
49+
Location loc, ArrayRef<unsigned> inRepOrder, ArrayRef<unsigned> repOrder) {
50+
ValueTableFMA res;
51+
auto elems = unpackLLElements(loc, val, rewriter);
52+
assert(perRepShape.size() == 3);
53+
auto numElemsRep = product(perRepShape);
54+
assert(elems.size() == numElemsRep * product(repetitions));
55+
assert(kDim == 1 || kDim == 2);
56+
assert(nonKDim == 1 || nonKDim == 2);
57+
const unsigned bDim = 0;
58+
59+
for (unsigned idx = 0; idx < elems.size(); ++idx) {
60+
auto inRepLinearIdx = idx % numElemsRep;
61+
auto repLinearIdx = idx / numElemsRep;
62+
auto inRepSpatialIdx =
63+
mlir::LLVM::delinearize(inRepLinearIdx, perRepShape, inRepOrder);
64+
auto repSpatialIdx =
65+
mlir::LLVM::delinearize(repLinearIdx, repetitions, repOrder);
66+
OperandValueKey key{repSpatialIdx[0], repSpatialIdx[nonKDim],
67+
inRepSpatialIdx[0], inRepSpatialIdx[nonKDim],
68+
inRepSpatialIdx[kDim]};
69+
res[key] = elems[idx];
70+
}
71+
return res;
72+
}
73+
74+
} // namespace
75+
76+
namespace mlir::triton::gpu::intel {
77+
78+
LogicalResult parametricConvertFMADot(DotOp op, DotOp::Adaptor adaptor,
79+
const LLVMTypeConverter *typeConverter,
80+
ConversionPatternRewriter &rewriter,
81+
FMAVectorMultiplier &multiplier) {
82+
auto *ctx = rewriter.getContext();
83+
auto loc = op.getLoc();
84+
85+
auto A = op.getA();
86+
auto D = op.getResult();
87+
88+
auto aTensorTy = cast<RankedTensorType>(A.getType());
89+
auto dTensorTy = cast<RankedTensorType>(D.getType());
90+
91+
SmallVector<int64_t> aShapePerCTA =
92+
expandMatrixShapeWithBatch(ArrayRef(getShapePerCTA(aTensorTy)));
93+
auto dShapePerCTA =
94+
expandMatrixShapeWithBatch(ArrayRef(getShapePerCTA(dTensorTy)));
95+
96+
BlockedEncodingAttr dLayout =
97+
cast<BlockedEncodingAttr>(dTensorTy.getEncoding());
98+
// TODO process A and B operand separately
99+
auto inRepOrder = expandMatrixOrderWithBatch(dLayout.getOrder());
100+
auto repOrder = expandMatrixOrderWithBatch(dLayout.getRepOrder());
101+
auto cc = unpackLLElements(loc, adaptor.getC(), rewriter);
102+
103+
Value llA = adaptor.getA();
104+
Value llB = adaptor.getB();
105+
106+
auto sizePerThread = getContigPerThread(dTensorTy);
107+
auto numElemsPerThread = product(sizePerThread);
108+
SmallVector<unsigned> shapePerCTATile;
109+
for (auto [reg, thread, warp] :
110+
llvm::zip(sizePerThread, dLayout.getThreadsPerWarp(),
111+
dLayout.getWarpsPerCTA())) {
112+
shapePerCTATile.push_back(reg * thread * warp);
113+
}
114+
shapePerCTATile = expandMatrixShapeWithBatch(ArrayRef(shapePerCTATile));
115+
sizePerThread = expandMatrixShapeWithBatch(ArrayRef(sizePerThread));
116+
117+
unsigned K = aShapePerCTA[2];
118+
119+
unsigned threadTileShape[3];
120+
unsigned repetitions[3];
121+
for (int i = 0; i < 3; ++i) {
122+
repetitions[i] =
123+
ceil(dShapePerCTA[i], static_cast<int64_t>(shapePerCTATile[i]));
124+
}
125+
126+
auto has = getValueTableFromStructFMA(
127+
llA, {sizePerThread[0], sizePerThread[1], K},
128+
{repetitions[0], repetitions[1], 1},
129+
/*kDim*/ 2, /*nonKDim*/ 1, rewriter, loc, inRepOrder, repOrder);
130+
auto hbs = getValueTableFromStructFMA(
131+
llB, {sizePerThread[0], K, sizePerThread[2]},
132+
{repetitions[0], 1, repetitions[2]},
133+
/*kDim*/ 1, /*nonKDim*/ 2, rewriter, loc, inRepOrder, repOrder);
134+
135+
SmallVector<Value> acc = cc;
136+
137+
for (unsigned bRep = 0; bRep < repetitions[0]; ++bRep)
138+
for (unsigned mRep = 0; mRep < repetitions[1]; ++mRep)
139+
for (unsigned nRep = 0; nRep < repetitions[2]; ++nRep)
140+
for (unsigned b = 0; b < sizePerThread[0]; ++b)
141+
for (unsigned m = 0; m < sizePerThread[1]; ++m)
142+
for (unsigned n = 0; n < sizePerThread[2]; ++n) {
143+
SmallVector<unsigned> multiDimAccumIdx = {b, m, n};
144+
unsigned linearInRepIdx =
145+
LLVM::linearize(multiDimAccumIdx, sizePerThread, inRepOrder);
146+
SmallVector<unsigned> multiDimRepIdx = {bRep, mRep, nRep};
147+
unsigned linearRepIdx =
148+
LLVM::linearize(multiDimRepIdx, repetitions, repOrder);
149+
unsigned linearAccumIdx =
150+
linearInRepIdx + linearRepIdx * numElemsPerThread;
151+
152+
SmallVector<Value> aOpVector;
153+
SmallVector<Value> bOpVector;
154+
155+
for (unsigned k = 0; k < K; ++k) {
156+
aOpVector.push_back(has.at({bRep, mRep, b, m, k}));
157+
bOpVector.push_back(hbs.at({bRep, nRep, b, n, k}));
158+
}
159+
160+
acc[linearAccumIdx] = multiplier.multiplyVectors(
161+
aOpVector, bOpVector, acc[linearAccumIdx]);
162+
}
163+
164+
auto res = packLLElements(loc, typeConverter, acc, rewriter, dTensorTy);
165+
rewriter.replaceOp(op, res);
166+
167+
return success();
168+
}
169+
170+
} // namespace mlir::triton::gpu::intel

0 commit comments

Comments
 (0)