@@ -165,6 +165,67 @@ struct DotOpMFMAConversionHelper {
165165 return processSubBlocks (numSubBlocks, acc, false , true );
166166 }
167167
168+ // / Dot operand layout minimal tile is kDimInstrSize elements across
169+ // / K dimension. If dot operand K dimension is smaller, layout
170+ // / assigns tensor elements to multiple different hardware locations.
171+ // / In this case mfma instruction adds elements in accumulator
172+ // / multiple times.
173+ // /
174+ // / Let say A=[1,2]; B=[3,4], C = A*B = 1*3+2*4 = 11
175+ // / Consider instruction K size is 4,
176+ // / in this case operands will be duplicated:
177+ // / A' = [1,2,1,2] B' = [3,4,3,4]
178+ // / C' = (1*3+2*4) + (1*3+2*4) = 22
179+ // /
180+ // / Following code adjusts accumulator values in such cases.
181+ // / If accumulator is integer, shift accumulator right by
182+ // / log2(duplicationRate). If accumulator is float, multiply accum
183+ // / with 1/duplicationRate constant.
184+ void adjustAccForSmallKDim (SmallVector<Value> &fc, Value &acc, Type dstElemTy,
185+ int b, int m, int n, int64_t numRepM,
186+ int64_t numRepN, int64_t kDimInstrSize ,
187+ int64_t kDimOperandSize ,
188+ unsigned elemsPerVec) const {
189+ auto tb = TritonLLVMOpBuilder (loc, rewriter);
190+ for (unsigned v = 0 ; v < elemsPerVec; ++v) {
191+ Value accElem = tb.extract_element (dstElemTy, acc, tb.i32_val (v));
192+ if (kDimInstrSize > kDimOperandSize ) {
193+ assert (kDimInstrSize % kDimOperandSize == 0 );
194+ int duplicationRate = kDimInstrSize / kDimOperandSize ;
195+ assert (llvm::isPowerOf2_32 (duplicationRate));
196+ if (dstElemTy.isInteger ()) {
197+ auto shiftSize = llvm::Log2_32 (duplicationRate);
198+ assert (!accElem.getType ().isUnsignedInteger () &&
199+ " MFMA uses signed accumulator" );
200+ accElem = tb.ashr (accElem, tb.i32_val (shiftSize));
201+ } else {
202+ auto multiplierAttr =
203+ rewriter.getFloatAttr (dstElemTy, 1.0 / duplicationRate);
204+ auto multiplierVal =
205+ rewriter.create <LLVM::ConstantOp>(loc, dstElemTy, multiplierAttr);
206+ accElem = tb.fmul (accElem, multiplierVal);
207+ }
208+ }
209+ auto linearIdx = b * numRepM * numRepN * elemsPerVec +
210+ m * numRepN * elemsPerVec + n * elemsPerVec + v;
211+ fc[linearIdx] = accElem;
212+ }
213+ }
214+
215+ void packAndReplaceResult (DotOp &op, SmallVector<Value> &fc,
216+ FailureOr<MfmaInsn> maybeMfmaInsn, Type dstElemTy,
217+ Type elemtTy, size_t mmaCount) const {
218+ Type structTy = LLVM::LLVMStructType::getLiteral (
219+ ctx, SmallVector<Type>(fc.size (), dstElemTy));
220+ Value res = packLLElements (loc, typeConverter, fc, rewriter, structTy);
221+
222+ setNumGeneratedMMAs (op, mmaCount, maybeMfmaInsn->getMDim (),
223+ maybeMfmaInsn->getNDim (), maybeMfmaInsn->getKDim (),
224+ elemtTy);
225+
226+ rewriter.replaceOp (op, res);
227+ }
228+
168229 // Conduct the Dot conversion.
169230 LogicalResult convertDot (DotOp op, DotOpAdaptor adaptor) const {
170231 auto tb = TritonLLVMOpBuilder (loc, rewriter);
@@ -243,11 +304,6 @@ struct DotOpMFMAConversionHelper {
243304 auto elemsPerVec = mDim * nDim * subBlocks / warpSize;
244305
245306 Value firstMfma;
246- auto setFirstMfma = [&](Value mfma) {
247- if (!firstMfma)
248- firstMfma = mfma;
249- };
250-
251307 auto vecTy = vec_ty (dstElemTy, elemsPerVec);
252308 for (int b = 0 ; b < numRepB; ++b) {
253309 for (int m = 0 ; m < numRepM; ++m) {
@@ -269,49 +325,13 @@ struct DotOpMFMAConversionHelper {
269325 operandA[kPack ][{b, m, k}], acc)
270326 : generateMFMAOp (mfmaInsnName, operandA[kPack ][{b, m, k}],
271327 operandB[kPack ][{b, n, k}], acc);
272- setFirstMfma (acc);
328+ if (!firstMfma)
329+ firstMfma = acc;
273330 }
274331 }
275332 acc = reduceSubBlocks (subBlocks, acc);
276- for (unsigned v = 0 ; v < elemsPerVec; ++v) {
277- Value accElem = tb.extract_element (dstElemTy, acc, tb.i32_val (v));
278- // Dot operand layout minimal tile is kDimInstrSize elements across
279- // K dimension. If dot operand K dimension is smaller, layout
280- // assigns tensor elements to multiple different hardware locations.
281- // In this case mfma instruction adds elements in accumulator
282- // multiple times.
283- //
284- // Let say A=[1,2]; B=[3,4], C = A*B = 1*3+2*4 = 11
285- // Consider instruction K size is 4,
286- // in this case operands will be duplicated:
287- // A' = [1,2,1,2] B' = [3,4,3,4]
288- // C' = (1*3+2*4) + (1*3+2*4) = 22
289- //
290- // Following code adjusts accumulator values in such cases.
291- // If accumulator is integer, shift accumulator right by
292- // log2(duplicationRate). If accumulator is float, multiply accum
293- // with 1/duplicationRate constant.
294- if (kDimInstrSize > kDimOperandSize ) {
295- assert (kDimInstrSize % kDimOperandSize == 0 );
296- int duplicationRate = kDimInstrSize / kDimOperandSize ;
297- assert (llvm::isPowerOf2_32 (duplicationRate));
298- if (dstElemTy.isInteger ()) {
299- auto shiftSize = llvm::Log2_32 (duplicationRate);
300- assert (!accElem.getType ().isUnsignedInteger () &&
301- " MFMA uses signed accumulator" );
302- accElem = tb.ashr (accElem, tb.i32_val (shiftSize));
303- } else {
304- auto multiplierAttr =
305- rewriter.getFloatAttr (dstElemTy, 1.0 / duplicationRate);
306- auto multiplierVal = rewriter.create <LLVM::ConstantOp>(
307- loc, dstElemTy, multiplierAttr);
308- accElem = tb.fmul (accElem, multiplierVal);
309- }
310- }
311- auto linearIdx = b * numRepM * numRepN * elemsPerVec +
312- m * numRepN * elemsPerVec + n * elemsPerVec + v;
313- fc[linearIdx] = accElem;
314- }
333+ adjustAccForSmallKDim (fc, acc, dstElemTy, b, m, n, numRepM, numRepN,
334+ kDimInstrSize , kDimOperandSize , elemsPerVec);
315335 }
316336 }
317337 }
@@ -325,19 +345,9 @@ struct DotOpMFMAConversionHelper {
325345 if (setPrioOp && firstMfma)
326346 setPrioOp->moveAfter (firstMfma.getDefiningOp ());
327347
328- // replace with new packed result
329- Type structTy = LLVM::LLVMStructType::getLiteral (
330- ctx, SmallVector<Type>(fc.size (), dstElemTy));
331- Value res = packLLElements (loc, typeConverter, fc, rewriter, structTy);
332-
333- Type elemtTy = elemTyA;
334348 const size_t mmaCount =
335349 numRepB * numRepM * numRepN * numRepK * kWidth / kBase ;
336- setNumGeneratedMMAs (op, mmaCount, maybeMfmaInsn->getMDim (),
337- maybeMfmaInsn->getNDim (), maybeMfmaInsn->getKDim (),
338- elemtTy);
339-
340- rewriter.replaceOp (op, res);
350+ packAndReplaceResult (op, fc, maybeMfmaInsn, dstElemTy, elemTyA, mmaCount);
341351
342352 return success ();
343353 }
0 commit comments