@@ -199,16 +199,37 @@ chooseMfmaInstruction(Location loc, int mfmaVersion, RankedTensorType cType,
       aElemType, bElemType, withScale, allowXF32);
 
   // Fallback to FMA if the M/N dim is not supported by MFMA.
-  if (failed(maybeMfmaIntrinsic))
+  if (failed(maybeMfmaIntrinsic)) {
+    mlir::emitRemark(loc) << "Unable to select MFMA intrinsic for the request: "
+                          << "version=" << mfmaVersion << ", result-shape=("
+                          << M << "x" << N << "), selected-tiles=(" << mDim
+                          << "x" << nDim << "), inputKSize=" << inputKSize
+                          << ", aElemType=" << aElemType
+                          << ", bElemType=" << bElemType
+                          << ", withScale=" << (withScale ? "true" : "false")
+                          << ", allowXF32=" << (allowXF32 ? "true" : "false")
+                          << (enforcedNonKDim != 0
+                                  ? (llvm::Twine(", enforcedNonKDim=") +
+                                     llvm::Twine(enforcedNonKDim))
+                                        .str()
+                                  : "");
     return failure();
+  }
 
   kDim = maybeMfmaIntrinsic->kDim;
   assert(kDim != 0);
   assert(enforcedNonKDim != 0 || (M % mDim == 0 && N % nDim == 0));
   // If inputKSize % kDim != 0 (including the case where inputKSize < kDim),
   // this layout will introduce data duplication.
-  if (inputKSize % kDim != 0)
+  if (inputKSize % kDim != 0) {
+    mlir::emitRemark(loc)
+        << "Unable to select MFMA intrinsic '" << maybeMfmaIntrinsic->name
+        << "' as its k-dimension size kDim=" << kDim
+        << " does not evenly divide the tile k-dimension size inputKSize="
+        << inputKSize
+        << ". Using this intrinsic would introduce data duplication.";
     return failure();
+  }
   return maybeMfmaIntrinsic;
 }
 
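To make the rejected-intrinsic condition above concrete, here is a minimal standalone sketch of the inputKSize % kDim check; the numeric values are hypothetical, chosen only for illustration, not taken from the patch.

    #include <cstdio>

    int main() {
      int inputKSize = 24; // assumed tile k-dimension of the dot operands
      int kDim = 16;       // assumed k-dimension of the selected MFMA intrinsic
      if (inputKSize % kDim != 0) {
        // 24 = 1 * 16 + 8: a second MFMA step along k would cover elements
        // 16..31, but only 16..23 exist, so 8 lanes would carry duplicated
        // (or padded) data. The pass therefore rejects the intrinsic.
        std::printf("kDim=%d does not evenly divide inputKSize=%d\n", kDim,
                    inputKSize);
      }
      return 0;
    }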
@@ -548,11 +569,15 @@ class BlockedToMFMA : public OpRewritePattern<tt::DotOp> {
         chooseMfmaInstruction(dotOp, mfmaVersion, nonKDim, withScale);
     if (failed(mfmaInstr)) {
       if (!withScale) {
-        return failure();
+        return rewriter.notifyMatchFailure(
+            dotOp,
+            "Unable to choose preferable MFMA intrinsic for dot operation.");
       }
       mfmaInstr = chooseMfmaInstruction(dotOp, mfmaVersion, nonKDim, false);
-      if (failed(mfmaInstr))
-        return failure();
+      if (failed(mfmaInstr)) {
+        return rewriter.notifyMatchFailure(
+            dotOp, "Unable to choose MFMA intrinsic for dot operation.");
+      }
 
       withScale = false;
     }
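A note on the two reporting mechanisms this patch mixes: remarks are ordinary diagnostics that can reach end users, while match-failure notes are pattern-driver bookkeeping that typically only surfaces in debug builds or with debug logging enabled. A small sketch against the generic MLIR API; the op and messages are placeholders, not code from this patch.

    #include "mlir/IR/Operation.h"
    #include "mlir/IR/PatternMatch.h"

    // Sketch only: contrasts the two kinds of diagnostics used in this patch.
    mlir::LogicalResult report(mlir::Operation *op,
                               mlir::PatternRewriter &rewriter) {
      // emitRemark produces a real diagnostic routed through the context's
      // DiagnosticEngine, so it can be surfaced to users.
      op->emitRemark() << "informational note about this op";

      // notifyMatchFailure records why the pattern did not apply; drivers
      // generally materialize the message only for compiler developers.
      return rewriter.notifyMatchFailure(op, "reason the pattern bailed out");
    }

That split matches the intent of the diff: remarks for user-visible intrinsic-selection decisions, match-failure messages for pattern bookkeeping.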
@@ -769,7 +794,8 @@ class ScaledBlockedToMFMA final : public OpRewritePattern<triton::DotScaledOp> {
     FailureOr<MfmaIntrinsic> mfmaInstr =
         chooseMfmaInstruction(dotOp, mfmaVersion, nonKDim, useFp16);
     if (failed(mfmaInstr))
-      return rewriter.notifyMatchFailure(dotOp, "cannot choose mfma intrinsic");
+      return rewriter.notifyMatchFailure(
+          dotOp, "Unable to choose MFMA intrinsic for scaled dot operation.");
 
     if (useFp16) {
       dotOp.emitRemark(
@@ -895,6 +921,13 @@ class DecomposeAMDScaledBlocked final : public ttg::DecomposeScaledBlocked {
       : ttg::DecomposeScaledBlocked(context, benefit) {}
   using TensorValue = TypedValue<RankedTensorType>;
 
+  LogicalResult matchAndRewrite(tt::DotScaledOp dotOp,
+                                PatternRewriter &rewriter) const override {
+    dotOp.emitRemark() << "Decomposing scaled dot operation into regular dot "
+                          "operation with explicit scaling.";
+    return ttg::DecomposeScaledBlocked::matchAndRewrite(dotOp, rewriter);
+  }
+
   RankedTensorType getScaleType(RankedTensorType vType, int32_t kDim,
                                 bool isFp4) const {
     if (!isFp4)
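The override added above follows a wrap-and-delegate shape; because the remark is emitted before the base pattern runs, it fires even when the base rewrite later fails to match. A generic sketch of the idiom with placeholder type names (RemarkingRewrite, BaseRewrite, and OpT are illustrative, not real Triton classes):

    // Illustrative only: emit a diagnostic, then defer to the base pattern.
    template <typename BaseRewrite, typename OpT>
    struct RemarkingRewrite : BaseRewrite {
      using BaseRewrite::BaseRewrite;

      mlir::LogicalResult
      matchAndRewrite(OpT op, mlir::PatternRewriter &rewriter) const override {
        // Unconditional: if the base matchAndRewrite returns failure, the
        // remark has already been issued anyway.
        op.emitRemark() << "attempting fallback decomposition";
        return BaseRewrite::matchAndRewrite(op, rewriter);
      }
    };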
@@ -1018,9 +1051,11 @@ class ScaledBlockedToScaledMFMAF8F6F4 final
     // Choose a suitable Scaled MFMA instruction for this scaled dot op.
     FailureOr<MfmaIntrinsic> mfmaInstr =
         chooseMfmaInstruction(dotOp, mfmaVersion, nonKDim);
-    if (failed(mfmaInstr))
+    if (failed(mfmaInstr)) {
       return rewriter.notifyMatchFailure(dotOp,
-                                         "cannot choose scaled mfma intrinsic");
+                                         "Unable to choose preferable MFMA "
+                                         "intrinsic for scaled dot operation.");
+    }
 
     auto mDim = mfmaInstr->mDim;
     auto nDim = mfmaInstr->nDim;
@@ -1474,7 +1509,8 @@ class BlockedToWMMA : public OpRewritePattern<tt::DotOp> {
     FailureOr<WmmaIntrinsic> wmmaInstr =
         chooseWmmaInstruction(dotOp, operandTypes, wmmaVersion, nonKDim);
     if (failed(wmmaInstr)) {
-      return failure();
+      return rewriter.notifyMatchFailure(
+          dotOp, "Unable to choose WMMA intrinsic for dot operation.");
     }
 
     auto mDim = wmmaInstr->mDim;
@@ -1625,7 +1661,8 @@ class AccelerateBlocked : public OpRewritePattern<DotOp> {
   LogicalResult tryAccelerateF16WithVDot(DotOp dotOp, PatternRewriter &rewriter,
                                          const DotElTypes &dotTypes) const {
     if (!AMD::supportsVDot(arch))
-      return failure();
+      return rewriter.notifyMatchFailure(
+          dotOp, "Target architecture does not support V_DOT instruction.");
 
     // If this is an fp16 x fp16 -> fp16 case, prioritize using v_dot.
     auto aOpType = dotOp.getA().getType();
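For readers unfamiliar with V_DOT: on AMD targets that support it, the v_dot2_f32_f16 instruction computes a two-element fp16 dot product and adds it to an f32 accumulator in a single instruction, which is the fast path this helper targets. A scalar reference sketch follows; _Float16 is a Clang/GCC extension, and the function name is made up for illustration.

    // Reference semantics of AMD's v_dot2_f32_f16 (not the actual lowering):
    //   d = a[0] * b[0] + a[1] * b[1] + c
    float vdot2_f32_f16_ref(const _Float16 a[2], const _Float16 b[2], float c) {
      return static_cast<float>(a[0]) * static_cast<float>(b[0]) +
             static_cast<float>(a[1]) * static_cast<float>(b[1]) + c;
    }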
@@ -1641,7 +1678,8 @@ class AccelerateBlocked : public OpRewritePattern<DotOp> {
       rewriter.replaceOp(dotOp, newD);
       return success();
     }
-    return failure();
+    return rewriter.notifyMatchFailure(
+        dotOp, "Unable to choose V_DOT instruction for dot operation.");
   }
 
   LogicalResult tryLegalizeFMA(DotOp dotOp, PatternRewriter &rewriter,
@@ -1687,7 +1725,10 @@ class AccelerateBlocked : public OpRewritePattern<DotOp> {
   LogicalResult matchAndRewrite(DotOp dotOp,
                                 PatternRewriter &rewriter) const override {
     if (!isa<BlockedEncodingAttr>(dotOp.getD().getType().getEncoding()))
-      return failure();
+      return rewriter.notifyMatchFailure(
+          dotOp, "expected blocked encoding result tensor");
+
+    dotOp.emitRemark() << "Attempting to map dot operation to FMA intrinsic.";
 
     DotElTypes dotTypes;
     dotTypes.a = dotOp.getA().getType().getElementType();
@@ -1697,7 +1738,8 @@ class AccelerateBlocked : public OpRewritePattern<DotOp> {
 
     // Check that dot is not legalized already
     if (isLegalFMAForm(dotOp, dotTypes)) {
-      return failure();
+      return rewriter.notifyMatchFailure(
+          dotOp, "Dot operation is already in FMA form.");
     }
 
     // TODO: enable this condition when fp32 -> fp16 cast works correctly
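Since most of the new messages are remarks rather than errors, they are easy to miss in normal runs. One way to observe them programmatically, sketched against MLIR's generic diagnostic API (the surrounding pipeline setup is elided; the remark-only filtering policy is an assumption for illustration, not part of this patch):

    #include "mlir/IR/Diagnostics.h"
    #include "mlir/IR/MLIRContext.h"
    #include "llvm/Support/raw_ostream.h"

    // Sketch: print only remark-severity diagnostics (such as the MFMA/FMA
    // selection notes added in this patch) while a pass pipeline runs.
    void printRemarks(mlir::MLIRContext &ctx) {
      mlir::ScopedDiagnosticHandler handler(&ctx, [](mlir::Diagnostic &diag) {
        if (diag.getSeverity() == mlir::DiagnosticSeverity::Remark) {
          llvm::errs() << "remark: " << diag.str() << "\n";
          return mlir::success(); // handled; stop propagation
        }
        return mlir::failure(); // defer non-remarks to other handlers
      });
      // ... run the relevant passes while `handler` is in scope ...
    }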