|
1 |
| -#include "../TritonGPUToLLVMBase.h" |
| 1 | +#include "TritonIntelGPUToLLVM/TypeConverter.h" |
| 2 | +#include "triton/Conversion/TritonGPUToLLVM/FMADotUtility.h" |
| 3 | +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" |
2 | 4 | #include "llvm/ADT/TypeSwitch.h"
|
3 | 5 |
|
4 | 6 | using namespace mlir;
|
5 | 7 | using namespace mlir::triton;
|
6 | 8 | using namespace ::mlir::triton::gpu;
|
7 | 9 |
|
8 |
| -using ::mlir::LLVM::linearize; |
9 |
| -using ::mlir::triton::gpu::expandMatrixOrderWithBatch; |
10 |
| -using ::mlir::triton::gpu::expandMatrixShapeWithBatch; |
11 |
| -using ::mlir::triton::gpu::getShapePerCTA; |
12 |
| - |
13 |
| -using ValueTableFMA = std::map<std::tuple<int, int, int>, Value>; |
14 |
| - |
15 |
| -static ValueTableFMA |
16 |
| -getValueTableFromStructFMA(Value val, ArrayRef<unsigned> perTileShape, |
17 |
| - unsigned kDim, unsigned nonKDim, |
18 |
| - ConversionPatternRewriter &rewriter, Location loc, |
19 |
| - ArrayRef<unsigned> order) { |
20 |
| - ValueTableFMA res; |
21 |
| - auto elems = unpackLLElements(loc, val, rewriter); |
22 |
| - assert(perTileShape.size() == 3); |
23 |
| - assert(elems.size() == product(perTileShape)); |
24 |
| - assert(kDim == 1 || kDim == 2); |
25 |
| - assert(nonKDim == 1 || nonKDim == 2); |
26 |
| - const unsigned bDim = 0; |
27 |
| - |
28 |
| - for (unsigned idx = 0; idx < elems.size(); ++idx) { |
29 |
| - auto spatialIdx = mlir::LLVM::delinearize(idx, perTileShape, order); |
30 |
| - res[{spatialIdx[bDim], spatialIdx[nonKDim], spatialIdx[kDim]}] = elems[idx]; |
| 10 | +namespace { |
| 11 | +class GenericFMAVectorMultiplier : public FMAVectorMultiplier { |
| 12 | + OpBuilder &builder; |
| 13 | + Location loc; |
| 14 | + |
| 15 | +public: |
| 16 | + GenericFMAVectorMultiplier(OpBuilder &builder, Location loc) |
| 17 | + : builder(builder), loc(loc) {} |
| 18 | + |
| 19 | + Value multiplyVectors(ArrayRef<Value> a, ArrayRef<Value> b, |
| 20 | + Value c) override { |
| 21 | + auto K = a.size(); |
| 22 | + assert(b.size() == K); |
| 23 | + Value accum = c; |
| 24 | + Type tgtTy = accum.getType(); |
| 25 | + for (auto it = llvm::zip(a, b).begin(); it != llvm::zip(a, b).end(); ++it) { |
| 26 | + const auto &aElem = std::get<0>(*it); |
| 27 | + const auto &bElem = std::get<1>(*it); |
| 28 | + |
| 29 | + assert(aElem.getType() == tgtTy); |
| 30 | + assert(bElem.getType() == tgtTy); |
| 31 | + |
| 32 | + llvm::TypeSwitch<Type>(tgtTy) |
| 33 | + .Case<FloatType>([&](auto) { |
| 34 | + accum = builder.create<LLVM::FMulAddOp>(loc, aElem, bElem, accum); |
| 35 | + }) |
| 36 | + .Case<IntegerType>([&](auto) { |
| 37 | + accum = builder.create<LLVM::AddOp>( |
| 38 | + loc, builder.create<LLVM::MulOp>(loc, aElem, bElem), accum); |
| 39 | + }); |
| 40 | + } |
| 41 | + return accum; |
31 | 42 | }
|
32 |
| - return res; |
33 |
| -} |
| 43 | +}; |
| 44 | + |
| 45 | +} // namespace |
34 | 46 |
|
35 | 47 | namespace fma_details {
|
36 |
| -LogicalResult convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor, |
37 |
| - TritonIntelGPUToLLVMTypeConverter *typeConverter, |
38 |
| - ConversionPatternRewriter &rewriter) { |
| 48 | + |
| 49 | +LogicalResult |
| 50 | +convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor, |
| 51 | + const TritonIntelGPUToLLVMTypeConverter *typeConverter, |
| 52 | + ConversionPatternRewriter &rewriter) { |
39 | 53 | auto *ctx = rewriter.getContext();
|
40 | 54 | auto loc = op.getLoc();
|
41 |
| - |
42 |
| - auto A = op.getA(); |
43 |
| - auto D = op.getResult(); |
44 |
| - |
45 |
| - auto aTensorTy = cast<RankedTensorType>(A.getType()); |
46 |
| - auto dTensorTy = cast<RankedTensorType>(D.getType()); |
47 |
| - |
48 |
| - SmallVector<int64_t> aShapePerCTA = |
49 |
| - expandMatrixShapeWithBatch(ArrayRef(getShapePerCTA(aTensorTy))); |
50 |
| - auto dShapePerCTA = |
51 |
| - expandMatrixShapeWithBatch(ArrayRef(getShapePerCTA(dTensorTy))); |
52 |
| - |
53 |
| - BlockedEncodingAttr dLayout = |
54 |
| - cast<BlockedEncodingAttr>(dTensorTy.getEncoding()); |
55 |
| - auto order = expandMatrixOrderWithBatch(dLayout.getOrder()); |
56 |
| - auto cc = unpackLLElements(loc, adaptor.getC(), rewriter); |
57 |
| - |
58 |
| - Value llA = adaptor.getA(); |
59 |
| - Value llB = adaptor.getB(); |
60 |
| - |
61 |
| - auto sizePerThread = getContigPerThread(dTensorTy); |
62 |
| - SmallVector<unsigned> shapePerCTATile; |
63 |
| - for (auto [reg, thread, warp] : |
64 |
| - llvm::zip(sizePerThread, dLayout.getThreadsPerWarp(), |
65 |
| - dLayout.getWarpsPerCTA())) { |
66 |
| - shapePerCTATile.push_back(reg * thread * warp); |
67 |
| - } |
68 |
| - shapePerCTATile = expandMatrixShapeWithBatch(ArrayRef(shapePerCTATile)); |
69 |
| - sizePerThread = expandMatrixShapeWithBatch(ArrayRef(sizePerThread)); |
70 |
| - |
71 |
| - unsigned K = aShapePerCTA[2]; |
72 |
| - |
73 |
| - unsigned perThreadShape[3]; |
74 |
| - for (int i = 0; i < 3; ++i) { |
75 |
| - unsigned numRep = dShapePerCTA[i] / shapePerCTATile[i]; |
76 |
| - numRep = std::max(static_cast<unsigned>(1), numRep); |
77 |
| - perThreadShape[i] = numRep * sizePerThread[i]; |
78 |
| - } |
79 |
| - |
80 |
| - auto has = getValueTableFromStructFMA( |
81 |
| - llA, {perThreadShape[0], perThreadShape[1], K}, |
82 |
| - /*kDim*/ 2, /*nonKDim*/ 1, rewriter, loc, order); |
83 |
| - auto hbs = getValueTableFromStructFMA( |
84 |
| - llB, {perThreadShape[0], K, perThreadShape[2]}, |
85 |
| - /*kDim*/ 1, /*nonKDim*/ 2, rewriter, loc, order); |
86 |
| - |
87 |
| - SmallVector<Value> acc = cc; |
88 |
| - |
89 |
| - for (unsigned b = 0; b < perThreadShape[0]; ++b) |
90 |
| - for (unsigned m = 0; m < perThreadShape[1]; ++m) |
91 |
| - for (unsigned n = 0; n < perThreadShape[2]; ++n) { |
92 |
| - SmallVector<unsigned> multiDimAccumIdx = {b, m, n}; |
93 |
| - unsigned linearAccumIdx = |
94 |
| - linearize(multiDimAccumIdx, perThreadShape, order); |
95 |
| - for (unsigned k = 0; k < K; ++k) { |
96 |
| - Type tgtTy = acc[linearAccumIdx].getType(); |
97 |
| - Value opA = has[{b, m, k}]; |
98 |
| - Value opB = hbs[{b, n, k}]; |
99 |
| - assert(opA.getType() == tgtTy); |
100 |
| - assert(opB.getType() == tgtTy); |
101 |
| - llvm::TypeSwitch<Type>(tgtTy) |
102 |
| - .Case<FloatType>([&](auto) { |
103 |
| - acc[linearAccumIdx] = rewriter.create<LLVM::FMulAddOp>( |
104 |
| - loc, opA, opB, acc[linearAccumIdx]); |
105 |
| - }) |
106 |
| - .Case<IntegerType>([&](auto) { |
107 |
| - acc[linearAccumIdx] = rewriter.create<LLVM::AddOp>( |
108 |
| - loc, rewriter.create<LLVM::MulOp>(loc, opA, opB), |
109 |
| - acc[linearAccumIdx]); |
110 |
| - }); |
111 |
| - } |
112 |
| - } |
113 |
| - |
114 |
| - auto res = packLLElements(loc, typeConverter, acc, rewriter, dTensorTy); |
115 |
| - rewriter.replaceOp(op, res); |
116 |
| - |
117 |
| - return success(); |
| 55 | + GenericFMAVectorMultiplier multiplier(rewriter, loc); |
| 56 | + return parametricConvertFMADot(op, adaptor, typeConverter, rewriter, |
| 57 | + multiplier); |
118 | 58 | }
|
| 59 | + |
119 | 60 | } // namespace fma_details
|
0 commit comments