| 
 | 1 | +#include "triton/Conversion/TritonGPUToLLVM/FMADotUtility.h"  | 
 | 2 | +#include "triton/Conversion/TritonGPUToLLVM/Utility.h"  | 
 | 3 | + | 
 | 4 | +using namespace mlir;  | 
 | 5 | + | 
 | 6 | +namespace {  | 
 | 7 | + | 
 | 8 | +/// OperandValueKey structure represents compile time part  | 
 | 9 | +/// of spatial coordinates of a value in a tensor.  | 
 | 10 | +///  | 
 | 11 | +/// Every Value spatial coordinates(i.e. [batch;nonK;k]) in tensor can be  | 
 | 12 | +/// defined as:  | 
 | 13 | +///  | 
 | 14 | +/// batch = (bRepIdx * CTABSize + bIdx) + (laneBCoord + warpBCoord)  | 
 | 15 | +/// nonK = (nonKRepIdx * CTANKSize + nonKIdx) + (laneNonKCoord + warpNonKCoord)  | 
 | 16 | +/// k = kIdx  | 
 | 17 | +///  | 
 | 18 | +/// Where:  | 
 | 19 | +/// CTABSize, CTANKSize: constants;  | 
 | 20 | +/// laneBCoord, warpBCoord, laneNonKCoord, warpNonKCoord: runtime components;  | 
 | 21 | +/// bRepIdx, nonKRepIdx, bIdx, nonKIdx, kIdx: compile time components.  | 
 | 22 | +struct OperandValueKey {  | 
 | 23 | +  unsigned bRepIdx, nonKRepIdx;  | 
 | 24 | +  unsigned bIdx, nonKIdx, kIdx;  | 
 | 25 | + | 
 | 26 | +  bool operator==(const OperandValueKey &other) const {  | 
 | 27 | +    return (bRepIdx == other.bRepIdx && nonKRepIdx == other.nonKRepIdx &&  | 
 | 28 | +            bIdx == other.bIdx && nonKIdx == other.nonKIdx &&  | 
 | 29 | +            kIdx == other.kIdx);  | 
 | 30 | +  }  | 
 | 31 | +};  | 
 | 32 | + | 
 | 33 | +} // namespace  | 
 | 34 | + | 
 | 35 | +template <> struct std::hash<OperandValueKey> {  | 
 | 36 | +  std::size_t operator()(const OperandValueKey &k) const {  | 
 | 37 | +    return llvm::hash_combine(k.bRepIdx, k.nonKRepIdx, k.bIdx, k.nonKIdx,  | 
 | 38 | +                              k.kIdx);  | 
 | 39 | +  }  | 
 | 40 | +};  | 
 | 41 | + | 
 | 42 | +namespace {  | 
 | 43 | + | 
 | 44 | +using ValueTableFMA = std::unordered_map<OperandValueKey, Value>;  | 
 | 45 | + | 
 | 46 | +ValueTableFMA getValueTableFromStructFMA(  | 
 | 47 | +    Value val, ArrayRef<unsigned> perRepShape, ArrayRef<unsigned> repetitions,  | 
 | 48 | +    unsigned kDim, unsigned nonKDim, ConversionPatternRewriter &rewriter,  | 
 | 49 | +    Location loc, ArrayRef<unsigned> inRepOrder, ArrayRef<unsigned> repOrder) {  | 
 | 50 | +  ValueTableFMA res;  | 
 | 51 | +  auto elems = unpackLLElements(loc, val, rewriter);  | 
 | 52 | +  assert(perRepShape.size() == 3);  | 
 | 53 | +  auto numElemsRep = product(perRepShape);  | 
 | 54 | +  assert(elems.size() == numElemsRep * product(repetitions));  | 
 | 55 | +  assert(kDim == 1 || kDim == 2);  | 
 | 56 | +  assert(nonKDim == 1 || nonKDim == 2);  | 
 | 57 | +  const unsigned bDim = 0;  | 
 | 58 | + | 
 | 59 | +  for (unsigned idx = 0; idx < elems.size(); ++idx) {  | 
 | 60 | +    auto inRepLinearIdx = idx % numElemsRep;  | 
 | 61 | +    auto repLinearIdx = idx / numElemsRep;  | 
 | 62 | +    auto inRepSpatialIdx =  | 
 | 63 | +        mlir::LLVM::delinearize(inRepLinearIdx, perRepShape, inRepOrder);  | 
 | 64 | +    auto repSpatialIdx =  | 
 | 65 | +        mlir::LLVM::delinearize(repLinearIdx, repetitions, repOrder);  | 
 | 66 | +    OperandValueKey key{repSpatialIdx[0], repSpatialIdx[nonKDim],  | 
 | 67 | +                        inRepSpatialIdx[0], inRepSpatialIdx[nonKDim],  | 
 | 68 | +                        inRepSpatialIdx[kDim]};  | 
 | 69 | +    res[key] = elems[idx];  | 
 | 70 | +  }  | 
 | 71 | +  return res;  | 
 | 72 | +}  | 
 | 73 | + | 
 | 74 | +} // namespace  | 
 | 75 | + | 
 | 76 | +namespace mlir::triton::gpu::intel {  | 
 | 77 | + | 
 | 78 | +LogicalResult parametricConvertFMADot(DotOp op, DotOp::Adaptor adaptor,  | 
 | 79 | +                                      const LLVMTypeConverter *typeConverter,  | 
 | 80 | +                                      ConversionPatternRewriter &rewriter,  | 
 | 81 | +                                      FMAVectorMultiplier &multiplier) {  | 
 | 82 | +  auto *ctx = rewriter.getContext();  | 
 | 83 | +  auto loc = op.getLoc();  | 
 | 84 | + | 
 | 85 | +  auto A = op.getA();  | 
 | 86 | +  auto D = op.getResult();  | 
 | 87 | + | 
 | 88 | +  auto aTensorTy = cast<RankedTensorType>(A.getType());  | 
 | 89 | +  auto dTensorTy = cast<RankedTensorType>(D.getType());  | 
 | 90 | + | 
 | 91 | +  SmallVector<int64_t> aShapePerCTA =  | 
 | 92 | +      expandMatrixShapeWithBatch(ArrayRef(getShapePerCTA(aTensorTy)));  | 
 | 93 | +  auto dShapePerCTA =  | 
 | 94 | +      expandMatrixShapeWithBatch(ArrayRef(getShapePerCTA(dTensorTy)));  | 
 | 95 | + | 
 | 96 | +  BlockedEncodingAttr dLayout =  | 
 | 97 | +      cast<BlockedEncodingAttr>(dTensorTy.getEncoding());  | 
 | 98 | +  // TODO process A and B operand separately  | 
 | 99 | +  auto inRepOrder = expandMatrixOrderWithBatch(dLayout.getOrder());  | 
 | 100 | +  auto repOrder = expandMatrixOrderWithBatch(dLayout.getRepOrder());  | 
 | 101 | +  auto cc = unpackLLElements(loc, adaptor.getC(), rewriter);  | 
 | 102 | + | 
 | 103 | +  Value llA = adaptor.getA();  | 
 | 104 | +  Value llB = adaptor.getB();  | 
 | 105 | + | 
 | 106 | +  auto sizePerThread = getContigPerThread(dTensorTy);  | 
 | 107 | +  auto numElemsPerThread = product(sizePerThread);  | 
 | 108 | +  SmallVector<unsigned> shapePerCTATile;  | 
 | 109 | +  for (auto [reg, thread, warp] :  | 
 | 110 | +       llvm::zip(sizePerThread, dLayout.getThreadsPerWarp(),  | 
 | 111 | +                 dLayout.getWarpsPerCTA())) {  | 
 | 112 | +    shapePerCTATile.push_back(reg * thread * warp);  | 
 | 113 | +  }  | 
 | 114 | +  shapePerCTATile = expandMatrixShapeWithBatch(ArrayRef(shapePerCTATile));  | 
 | 115 | +  sizePerThread = expandMatrixShapeWithBatch(ArrayRef(sizePerThread));  | 
 | 116 | + | 
 | 117 | +  unsigned K = aShapePerCTA[2];  | 
 | 118 | + | 
 | 119 | +  unsigned threadTileShape[3];  | 
 | 120 | +  unsigned repetitions[3];  | 
 | 121 | +  for (int i = 0; i < 3; ++i) {  | 
 | 122 | +    repetitions[i] =  | 
 | 123 | +        ceil(dShapePerCTA[i], static_cast<int64_t>(shapePerCTATile[i]));  | 
 | 124 | +  }  | 
 | 125 | + | 
 | 126 | +  auto has = getValueTableFromStructFMA(  | 
 | 127 | +      llA, {sizePerThread[0], sizePerThread[1], K},  | 
 | 128 | +      {repetitions[0], repetitions[1], 1},  | 
 | 129 | +      /*kDim*/ 2, /*nonKDim*/ 1, rewriter, loc, inRepOrder, repOrder);  | 
 | 130 | +  auto hbs = getValueTableFromStructFMA(  | 
 | 131 | +      llB, {sizePerThread[0], K, sizePerThread[2]},  | 
 | 132 | +      {repetitions[0], 1, repetitions[2]},  | 
 | 133 | +      /*kDim*/ 1, /*nonKDim*/ 2, rewriter, loc, inRepOrder, repOrder);  | 
 | 134 | + | 
 | 135 | +  SmallVector<Value> acc = cc;  | 
 | 136 | + | 
 | 137 | +  for (unsigned bRep = 0; bRep < repetitions[0]; ++bRep)  | 
 | 138 | +    for (unsigned mRep = 0; mRep < repetitions[1]; ++mRep)  | 
 | 139 | +      for (unsigned nRep = 0; nRep < repetitions[2]; ++nRep)  | 
 | 140 | +        for (unsigned b = 0; b < sizePerThread[0]; ++b)  | 
 | 141 | +          for (unsigned m = 0; m < sizePerThread[1]; ++m)  | 
 | 142 | +            for (unsigned n = 0; n < sizePerThread[2]; ++n) {  | 
 | 143 | +              SmallVector<unsigned> multiDimAccumIdx = {b, m, n};  | 
 | 144 | +              unsigned linearInRepIdx =  | 
 | 145 | +                  LLVM::linearize(multiDimAccumIdx, sizePerThread, inRepOrder);  | 
 | 146 | +              SmallVector<unsigned> multiDimRepIdx = {bRep, mRep, nRep};  | 
 | 147 | +              unsigned linearRepIdx =  | 
 | 148 | +                  LLVM::linearize(multiDimRepIdx, repetitions, repOrder);  | 
 | 149 | +              unsigned linearAccumIdx =  | 
 | 150 | +                  linearInRepIdx + linearRepIdx * numElemsPerThread;  | 
 | 151 | + | 
 | 152 | +              SmallVector<Value> aOpVector;  | 
 | 153 | +              SmallVector<Value> bOpVector;  | 
 | 154 | + | 
 | 155 | +              for (unsigned k = 0; k < K; ++k) {  | 
 | 156 | +                aOpVector.push_back(has.at({bRep, mRep, b, m, k}));  | 
 | 157 | +                bOpVector.push_back(hbs.at({bRep, nRep, b, n, k}));  | 
 | 158 | +              }  | 
 | 159 | + | 
 | 160 | +              acc[linearAccumIdx] = multiplier.multiplyVectors(  | 
 | 161 | +                  aOpVector, bOpVector, acc[linearAccumIdx]);  | 
 | 162 | +            }  | 
 | 163 | + | 
 | 164 | +  auto res = packLLElements(loc, typeConverter, acc, rewriter, dTensorTy);  | 
 | 165 | +  rewriter.replaceOp(op, res);  | 
 | 166 | + | 
 | 167 | +  return success();  | 
 | 168 | +}  | 
 | 169 | + | 
 | 170 | +} // namespace mlir::triton::gpu::intel  | 
0 commit comments