@@ -37,7 +37,6 @@ using ::mlir::LLVM::AMD::shuffleXor;
3737using ::mlir::triton::gpu::AMDMfmaEncodingAttr;
3838using ::mlir::triton::gpu::DotOperandEncodingAttr;
3939using ::mlir::triton::gpu::LinearEncodingAttr;
40- using ::mlir::triton::gpu::SwizzledSharedEncodingAttr;
4140
4241using ValueTable = std::map<std::array<int , 3 >, Value>;
4342
@@ -75,12 +74,12 @@ struct DotOpMFMAConversionHelper {
7574 : mfmaLayout(mfmaLayout), rewriter(rewriter),
7675 typeConverter(typeConverter), loc(loc), ctx(mfmaLayout.getContext()) {}
7776
78- Value generateMFMAOp (StringRef mfmaInsnName , Value valA, Value valB,
77+ Value generateMFMAOp (StringRef intrinsicName , Value valA, Value valB,
7978 Value valC) const {
8079 auto b = TritonLLVMOpBuilder (loc, rewriter);
8180 auto resType = valC.getType ();
8281 Value zeroFlag = b.i32_val (0 );
83- OperationState loweredOp (loc, mfmaInsnName );
82+ OperationState loweredOp (loc, intrinsicName );
8483 loweredOp.addTypes (resType);
8584 loweredOp.addOperands ({valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
8685 return rewriter.create (loweredOp)->getResult (0 );
@@ -228,14 +227,15 @@ struct DotOpMFMAConversionHelper {
228227
229228 template <typename T>
230229 void packAndReplaceResult (T &op, SmallVector<Value> &fc,
231- FailureOr<MfmaInsn> maybeMfmaInsn, Type dstElemTy,
232- Type elemtTy, size_t mmaCount) const {
230+ const FailureOr<MfmaIntrinsic> &maybeMfmaIntrinsic,
231+ Type dstElemTy, Type elemtTy,
232+ size_t mmaCount) const {
233233 Type structTy = LLVM::LLVMStructType::getLiteral (
234234 ctx, SmallVector<Type>(fc.size (), dstElemTy));
235235 Value res = packLLElements (loc, typeConverter, fc, rewriter, structTy);
236236
237- setNumGeneratedMMAs (op, mmaCount, maybeMfmaInsn-> getMDim () ,
238- maybeMfmaInsn-> getNDim (), maybeMfmaInsn-> getKDim () ,
237+ setNumGeneratedMMAs (op, mmaCount, maybeMfmaIntrinsic-> mDim ,
238+ maybeMfmaIntrinsic-> nDim , maybeMfmaIntrinsic-> kDim ,
239239 elemtTy);
240240
241241 rewriter.replaceOp (op, res);
@@ -267,14 +267,15 @@ struct DotOpMFMAConversionHelper {
267267
268268 bool allowXF32 =
269269 op.getInputPrecision () == InputPrecision::TF32 && mfmaVersion == 3 ;
270- StringRef mfmaInsnName;
271- auto maybeMfmaInsn = MfmaInsn::selectMfma (
272- mDim , nDim, kDimOperandSize , elemTyA, elemTyB, mfmaVersion, allowXF32);
273- if (failed (maybeMfmaInsn))
270+ StringRef intrinsicName;
271+ FailureOr<MfmaIntrinsic> maybeMfmaIntrinsic = MfmaIntrinsic::selectFor (
272+ mfmaVersion, mDim , nDim, kDimOperandSize , elemTyA, elemTyB,
273+ /* withScale=*/ false , allowXF32);
274+ if (failed (maybeMfmaIntrinsic))
274275 llvm::report_fatal_error (" No match found in MFMA database\n " );
275276
276- mfmaInsnName = maybeMfmaInsn-> getInsnName () ;
277- unsigned kBase = maybeMfmaInsn-> getKBase () ;
277+ intrinsicName = maybeMfmaIntrinsic-> name ;
278+ unsigned kBase = maybeMfmaIntrinsic-> kBase ;
278279
279280 auto aEncoding = cast<DotOperandEncodingAttr>(aTensorTy.getEncoding ());
280281 auto bEncoding = cast<DotOperandEncodingAttr>(bTensorTy.getEncoding ());
@@ -301,7 +302,7 @@ struct DotOpMFMAConversionHelper {
301302 auto numRepB = repA[0 ];
302303 assert (repA[0 ] == repB[0 ]);
303304
304- bool preserveBF16 = mfmaInsnName .contains (" .bf16" ) && mfmaVersion >= 4 ;
305+ bool preserveBF16 = intrinsicName .contains (" .bf16" ) && mfmaVersion >= 4 ;
305306 auto operandA = getValuesFromDotOperandLayoutStruct (
306307 loadedA, numRepB, numRepM, numRepK, kWidth , kBase ,
307308 aTensorTy.getElementType (), allowXF32, preserveBF16);
@@ -335,12 +336,13 @@ struct DotOpMFMAConversionHelper {
335336 acc = zeroAuxiliarBlocks (subBlocks, acc);
336337 for (int k = 0 ; k < numRepK; k++) {
337338 for (int kPack = 0 ; kPack < kWidth / kBase ; ++kPack ) {
338- acc =
339- mfmaLayout.getIsTransposed ()
340- ? generateMFMAOp (mfmaInsnName, operandB[kPack ][{b, n, k}],
341- operandA[kPack ][{b, m, k}], acc)
342- : generateMFMAOp (mfmaInsnName, operandA[kPack ][{b, m, k}],
343- operandB[kPack ][{b, n, k}], acc);
339+ acc = mfmaLayout.getIsTransposed ()
340+ ? generateMFMAOp (intrinsicName,
341+ operandB[kPack ][{b, n, k}],
342+ operandA[kPack ][{b, m, k}], acc)
343+ : generateMFMAOp (intrinsicName,
344+ operandA[kPack ][{b, m, k}],
345+ operandB[kPack ][{b, n, k}], acc);
344346 if (!firstMfma)
345347 firstMfma = acc;
346348 }
@@ -363,7 +365,8 @@ struct DotOpMFMAConversionHelper {
363365
364366 const size_t mmaCount =
365367 numRepB * numRepM * numRepN * numRepK * kWidth / kBase ;
366- packAndReplaceResult (op, fc, maybeMfmaInsn, dstElemTy, elemTyA, mmaCount);
368+ packAndReplaceResult (op, fc, maybeMfmaIntrinsic, dstElemTy, elemTyA,
369+ mmaCount);
367370
368371 return success ();
369372 }
@@ -485,15 +488,15 @@ struct ScaledDotOpMFMAConversionHelper : DotOpMFMAConversionHelper {
485488 Location loc)
486489 : DotOpMFMAConversionHelper(mfmaLayout, rewriter, typeConverter, loc) {}
487490
488- Value generateScaledMFMAOp (MfmaInsn &mfmaInsn , Value valA, Value valB ,
489- Value valC, Value valScaleA,
491+ Value generateScaledMFMAOp (const MfmaIntrinsic &mfmaIntrinsic , Value valA,
492+ Value valB, Value valC, Value valScaleA,
490493 Value valScaleB) const {
491494 auto b = TritonLLVMOpBuilder (loc, rewriter);
492495 auto resType = valC.getType ();
493496 Value zeroFlag = b.i32_val (0 );
494- OperationState loweredOp (loc, mfmaInsn. getInsnName () );
495- int32_t cbsz = getMfmaF8F6F4MatrixFormat (mfmaInsn. getElementTypeA () );
496- int32_t blgp = getMfmaF8F6F4MatrixFormat (mfmaInsn. getElementTypeB () );
497+ OperationState loweredOp (loc, mfmaIntrinsic. name );
498+ int32_t cbsz = getMfmaF8F6F4MatrixFormat (mfmaIntrinsic. aElementType );
499+ int32_t blgp = getMfmaF8F6F4MatrixFormat (mfmaIntrinsic. bElementType );
497500 assert ((cbsz != -1 ) && (blgp != -1 ));
498501 loweredOp.addTypes (resType);
499502 loweredOp.addOperands ({valA, valB, valC, b.i32_val (cbsz), b.i32_val (blgp),
@@ -540,14 +543,16 @@ struct ScaledDotOpMFMAConversionHelper : DotOpMFMAConversionHelper {
540543
541544 auto ctx = op.getContext ();
542545 constexpr bool allowXF32 = false ;
543- auto maybeMfmaInsn = MfmaInsn::selectMfma (
544- mDim , nDim, kDimOperandSize , scaleDotElemTypeToMLIRType (ctx, aElemType),
545- scaleDotElemTypeToMLIRType (ctx, bElemType), mfmaVersion, allowXF32);
546- if (failed (maybeMfmaInsn))
546+ FailureOr<MfmaIntrinsic> maybeMfmaIntrinsic =
547+ MfmaIntrinsic::selectFor (mfmaVersion, mDim , nDim, kDimOperandSize ,
548+ scaleDotElemTypeToMLIRType (ctx, aElemType),
549+ scaleDotElemTypeToMLIRType (ctx, bElemType),
550+ /* withScale=*/ false , allowXF32);
551+ if (failed (maybeMfmaIntrinsic))
547552 llvm::report_fatal_error (" No match found in MFMA database\n " );
548553
549- StringRef mfmaInsnName = maybeMfmaInsn-> getInsnName () ;
550- unsigned kBase = maybeMfmaInsn-> getKBase () ;
554+ StringRef intrinsicName = maybeMfmaIntrinsic-> name ;
555+ unsigned kBase = maybeMfmaIntrinsic-> kBase ;
551556 // Two fp4 are packed into an uint8.
552557 if (aElemType == ScaleDotElemType::E2M1) {
553558 kBase /= 2 ;
@@ -629,12 +634,12 @@ struct ScaledDotOpMFMAConversionHelper : DotOpMFMAConversionHelper {
629634 for (int k = 0 ; k < numRepK; k++) {
630635 for (int kPack = 0 ; kPack < kWidth / kBase ; ++kPack ) {
631636 acc = mfmaLayout.getIsTransposed ()
632- ? generateScaledMFMAOp (maybeMfmaInsn .value (),
637+ ? generateScaledMFMAOp (maybeMfmaIntrinsic .value (),
633638 operandB[kPack ][{b, n, k}],
634639 operandA[kPack ][{b, m, k}], acc,
635640 operandBScale[kPack ][{b, n, k}],
636641 operandAScale[kPack ][{b, m, k}])
637- : generateScaledMFMAOp (maybeMfmaInsn .value (),
642+ : generateScaledMFMAOp (maybeMfmaIntrinsic .value (),
638643 operandA[kPack ][{b, m, k}],
639644 operandB[kPack ][{b, n, k}], acc,
640645 operandAScale[kPack ][{b, m, k}],
@@ -661,7 +666,8 @@ struct ScaledDotOpMFMAConversionHelper : DotOpMFMAConversionHelper {
661666
662667 const size_t mmaCount =
663668 numRepB * numRepM * numRepN * numRepK * kWidth / kBase ;
664- packAndReplaceResult (op, fc, maybeMfmaInsn, dstElemTy, elemTyA, mmaCount);
669+ packAndReplaceResult (op, fc, maybeMfmaIntrinsic, dstElemTy, elemTyA,
670+ mmaCount);
665671
666672 return success ();
667673 }
0 commit comments