@@ -13,24 +13,51 @@ using ::mlir::triton::gpu::expandMatrixShapeWithBatch;
1313using ::mlir::triton::gpu::getShapePerCTA;
1414using ::mlir::triton::gpu::getSizePerThread;
1515
/// \brief Identifies one scalar element of a dot operand: the repetition tile
/// it belongs to (batch / non-K repetition indices) and its register position
/// inside that repetition (batch / non-K / K indices).
struct OperandValueKey {
  unsigned bRepIdx, nonKRepIdx;
  unsigned bIdx, nonKIdx, kIdx;

  bool operator==(const OperandValueKey &other) const {
    // Compare repetition coordinates first, then in-repetition coordinates.
    if (bRepIdx != other.bRepIdx || nonKRepIdx != other.nonKRepIdx)
      return false;
    if (bIdx != other.bIdx || nonKIdx != other.nonKIdx)
      return false;
    return kIdx == other.kIdx;
  }
};
27+
/// Hash support so OperandValueKey can serve as an unordered_map key.
/// Must hash every field that operator== compares.
template <> struct std::hash<OperandValueKey> {
  std::size_t operator()(const OperandValueKey &k) const {
    // Mix all five coordinates into one value via llvm::hash_combine.
    return llvm::hash_combine(k.bRepIdx, k.nonKRepIdx, k.bIdx, k.nonKIdx,
                              k.kIdx);
  }
};

// Maps the (repetition, in-repetition) coordinates of an operand element to
// its unpacked LLVM value.
using ValueTableFMA = std::unordered_map<OperandValueKey, Value>;
1736
18- static ValueTableFMA
19- getValueTableFromStructFMA (Value val, ArrayRef<unsigned > perTileShape,
20- unsigned kDim , unsigned nonKDim,
21- ConversionPatternRewriter &rewriter, Location loc,
22- ArrayRef<unsigned > order) {
37+ static ValueTableFMA getValueTableFromStructFMA (
38+ Value val, ArrayRef<unsigned > perRepShape, ArrayRef<unsigned > repetitions,
39+ unsigned kDim , unsigned nonKDim, ConversionPatternRewriter &rewriter,
40+ Location loc, ArrayRef<unsigned > inRepOrder, ArrayRef<unsigned > repOrder) {
2341 ValueTableFMA res;
2442 auto elems = unpackLLElements (loc, val, rewriter);
25- assert (perTileShape.size () == 3 );
26- assert (elems.size () == product (perTileShape));
43+ assert (perRepShape.size () == 3 );
44+ auto numElemsRep = product (perRepShape);
45+ assert (elems.size () == numElemsRep * product (repetitions));
2746 assert (kDim == 1 || kDim == 2 );
2847 assert (nonKDim == 1 || nonKDim == 2 );
2948 const unsigned bDim = 0 ;
3049
3150 for (unsigned idx = 0 ; idx < elems.size (); ++idx) {
32- auto spatialIdx = mlir::LLVM::delinearize (idx, perTileShape, order);
33- res[{spatialIdx[bDim], spatialIdx[nonKDim], spatialIdx[kDim ]}] = elems[idx];
51+ auto inRepLinearIdx = idx % numElemsRep;
52+ auto repLinearIdx = idx / numElemsRep;
53+ auto inRepSpatialIdx =
54+ mlir::LLVM::delinearize (inRepLinearIdx, perRepShape, inRepOrder);
55+ auto repSpatialIdx =
56+ mlir::LLVM::delinearize (repLinearIdx, repetitions, repOrder);
57+ OperandValueKey key{repSpatialIdx[0 ], repSpatialIdx[nonKDim],
58+ inRepSpatialIdx[0 ], inRepSpatialIdx[nonKDim],
59+ inRepSpatialIdx[kDim ]};
60+ res[key] = elems[idx];
3461 }
3562 return res;
3663}
  BlockedEncodingAttr dLayout =
      cast<BlockedEncodingAttr>(dTensorTy.getEncoding());
  // TODO: process A and B operands separately (currently both use the order
  // of the D layout).
  auto inRepOrder = expandMatrixOrderWithBatch(dLayout.getOrder());
  auto repOrder = expandMatrixOrderWithBatch(dLayout.getRepOrder());
  auto cc = unpackLLElements(loc, adaptor.getC(), rewriter);

  Value llA = adaptor.getA();
  Value llB = adaptor.getB();

  auto sizePerThread =
      expandMatrixShapeWithBatch(ArrayRef(getSizePerThread(dLayout)));
  auto numElemsPerThread = product(sizePerThread);
  auto shapePerCTATile =
      expandMatrixShapeWithBatch(ArrayRef(getShapePerCTATile(dLayout)));

  unsigned K = aShapePerCTA[2];

  // NOTE(review): threadTileShape is declared but never used below —
  // candidate for removal.
  unsigned threadTileShape[3];
  unsigned repetitions[3];
  for (int i = 0; i < 3; ++i) {
    // Presumably a ceiling division: number of CTA-tile repetitions needed to
    // cover dimension i of D — TODO confirm `ceil` semantics.
    repetitions[i] =
        ceil(dShapePerCTA[i], static_cast<int64_t>(shapePerCTATile[i]));
  }

  // Element tables for A and B, keyed by (repetition, in-rep) coordinates.
  // For A the non-K dim is 1 (M) and K is dim 2; for B the non-K dim is 2 (N)
  // and K is dim 1; neither operand repeats along its K axis here.
  auto has = getValueTableFromStructFMA(
      llA, {sizePerThread[0], sizePerThread[1], K},
      {repetitions[0], repetitions[1], 1},
      /*kDim*/ 2, /*nonKDim*/ 1, rewriter, loc, inRepOrder, repOrder);
  auto hbs = getValueTableFromStructFMA(
      llB, {sizePerThread[0], K, sizePerThread[2]},
      {repetitions[0], 1, repetitions[2]},
      /*kDim*/ 1, /*nonKDim*/ 2, rewriter, loc, inRepOrder, repOrder);

  // Accumulator starts from the unpacked C elements.
  SmallVector<Value> acc = cc;

  for (unsigned bRep = 0; bRep < repetitions[0]; ++bRep)
    for (unsigned mRep = 0; mRep < repetitions[1]; ++mRep)
      for (unsigned nRep = 0; nRep < repetitions[2]; ++nRep)
        for (unsigned b = 0; b < sizePerThread[0]; ++b)
          for (unsigned m = 0; m < sizePerThread[1]; ++m)
            for (unsigned n = 0; n < sizePerThread[2]; ++n) {
              // Accumulator layout is repetition-major: linear index within
              // the repetition, offset by the repetition's linear index times
              // the number of elements per repetition.
              SmallVector<unsigned> multiDimAccumIdx = {b, m, n};
              unsigned linearInRepIdx =
                  linearize(multiDimAccumIdx, sizePerThread, inRepOrder);
              SmallVector<unsigned> multiDimRepIdx = {bRep, mRep, nRep};
              unsigned linearRepIdx =
                  linearize(multiDimRepIdx, repetitions, repOrder);
              unsigned linearAccumIdx =
                  linearInRepIdx + linearRepIdx * numElemsPerThread;
              // Inner K loop: acc += a * b as a fused multiply-add.
              for (unsigned k = 0; k < K; ++k) {
                auto aOp = has[{bRep, mRep, b, m, k}];
                auto bOp = hbs[{bRep, nRep, b, n, k}];
                acc[linearAccumIdx] = rewriter.create<LLVM::FMulAddOp>(
                    loc, aOp, bOp, acc[linearAccumIdx]);
              }
            }

  auto res = packLLElements(loc, typeConverter, acc, rewriter, dTensorTy);
  rewriter.replaceOp(op, res);