@@ -497,24 +497,27 @@ class ScaledBlockedToMFMA final : public OpRewritePattern<triton::DotScaledOp> {
     if (!isa_and_nonnull<BlockedEncodingAttr>(oldRetType.getEncoding()))
       return rewriter.notifyMatchFailure(
           dotOp, "expected blocked encoding result tensor");
-
-    if (dotOp.getRhsScale())
-      return rewriter.notifyMatchFailure(dotOp, "NYI: RHS scale");
+    unsigned rank = oldRetType.getRank();
+    if (rank == 3)
+      return rewriter.notifyMatchFailure(dotOp, "NYI: 3d case");
 
     TensorValue a = dotOp.getLhs();
     TensorValue b = dotOp.getRhs();
     TensorValue aScale = dotOp.getLhsScale();
+    TensorValue bScale = dotOp.getRhsScale();
+    if (aScale && bScale)
+      return rewriter.notifyMatchFailure(dotOp, "NYI: both LHS and RHS scale");
+
     ScaleDotElemType aElemType = dotOp.getLhsType();
     ScaleDotElemType bElemType = dotOp.getRhsType();
-
-    if (!(aElemType == ScaleDotElemType::E2M1 ||
-          aElemType == ScaleDotElemType::E4M3 ||
-          aElemType == ScaleDotElemType::E5M2))
-      return rewriter.notifyMatchFailure(dotOp, "NYI: non-mxfp8/mxfp4 LHS");
-    if (!(bElemType == ScaleDotElemType::E4M3 ||
-          bElemType == ScaleDotElemType::E5M2 ||
-          bElemType == ScaleDotElemType::BF16))
-      return rewriter.notifyMatchFailure(dotOp, "NYI: non-fp8/bf16 RHS");
+    auto supportsTypes = [](ScaleDotElemType elemType) {
+      return elemType == ScaleDotElemType::E2M1 ||
+             elemType == ScaleDotElemType::E4M3 ||
+             elemType == ScaleDotElemType::E5M2 ||
+             elemType == ScaleDotElemType::BF16;
+    };
+    if (!supportsTypes(aElemType) || !supportsTypes(bElemType))
+      return rewriter.notifyMatchFailure(dotOp, "NYI: mxfp6 operand");
 
     MLIRContext *ctx = dotOp.getContext();
     auto moduleOp = dotOp->getParentOfType<ModuleOp>();
@@ -534,27 +537,30 @@ class ScaledBlockedToMFMA final : public OpRewritePattern<triton::DotScaledOp> {
     unsigned kDim = mfmaInstr.value().getKDim();
     unsigned kBase = mfmaInstr.value().getKBase();
 
-    // If A tensor contains mxfp4, we pack every two values into one int8 value
-    // there. For such cases, we have different initial kWidth for LHS and RHS,
-    // which will be "fixed" later by using upcast_mxfp to convert LHS to
-    // unpacked values. For such packed cases, we cannot support flexible kPack
-    // choices from the developer--it just does not apply here. So mandate the
-    // choice here.
-    bool isPacked = aElemType == ScaleDotElemType::E2M1;
-    unsigned kWdiths[] = {isPacked ? 4 : kBase * kPack,
-                          isPacked ? 8 : kBase * kPack};
-
-    // For A tensor, 32 consecutive elements along K dim share the same scale.
+    // For mxfp4 A/B tensor, we pack every two values into one int8 value there.
+    // For such cases, we have different initial kWidth for LHS and RHS, which
+    // will be "fixed" later by using upcast_mxfp to convert LHS to unpacked
+    // values. For such packed cases, we cannot support flexible kPack choices
+    // from the developer--it just does not apply here. So mandate the choice
+    // here.
+    bool isAPacked = aElemType == ScaleDotElemType::E2M1;
+    bool isBPacked = bElemType == ScaleDotElemType::E2M1;
+    bool isPacked = isAPacked || isBPacked;
+    unsigned kWdiths[] = {isPacked ? (isAPacked ? 4 : 8) : kBase * kPack,
+                          isPacked ? (isAPacked ? 8 : 4) : kBase * kPack};
+
+    // For A/B tensor, 32 consecutive elements along K dim share the same scale.
     // We'd like to keep the scale values together with the base values in the
     // same warp to avoid cross-warp data exchange. It means we want warpsPerCTA
-    // = 1 along the N dimension.
-    SmallVector<unsigned, 3> warpsPerCTA(oldRetType.getRank(), 1);
-    warpsPerCTA.front() = numWarps;
+    // = 1 along the N/M dimension for the mxfp A/B case. We achieve that by
+    // setting the M/N dimension as numWarps.
+    SmallVector<unsigned, 2> mfmaWarpsPerCTA(rank, 1);
+    mfmaWarpsPerCTA[aScale ? 0 : 1] = numWarps;
 
     // Always use transposed mfma layout. This enables larger vectorization
     // for global store instructions.
     auto mfmaEnc = ttg::AMDMfmaEncodingAttr::get(
-        ctx, /*versionMajor=*/mfmaVersion, /*versionMinor=*/0, warpsPerCTA,
+        ctx, /*versionMajor=*/mfmaVersion, /*versionMinor=*/0, mfmaWarpsPerCTA,
         /*instrShape=*/mDim, nDim, /*isTransposed=*/true, ctaLayout);
 
     auto newRetType = RankedTensorType::get(
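To make the per-operand kWidth rule in the hunk above concrete, here is a minimal standalone sketch (not part of the patch; the helper name pickKWidths and the example kBase/kPack values are invented for illustration): a packed mxfp4 operand stores two fp4 values per int8, so it starts with half the kWidth of its counterpart and is unpacked later by upcast_mxfp, while the unpacked case uses kBase * kPack on both sides.

#include <array>
#include <cassert>

// Hypothetical helper mirroring the kWidth selection in the pattern above.
static std::array<unsigned, 2> pickKWidths(bool isAPacked, bool isBPacked,
                                           unsigned kBase, unsigned kPack) {
  bool isPacked = isAPacked || isBPacked;
  if (!isPacked)
    return {kBase * kPack, kBase * kPack};
  // The packed mxfp4 side starts at kWidth 4; the other side at 8.
  return {isAPacked ? 4u : 8u, isAPacked ? 8u : 4u};
}

int main() {
  // Example values only: kBase = 8, kPack = 1 (assumed for illustration).
  auto plain = pickKWidths(/*isAPacked=*/false, /*isBPacked=*/false, 8, 1);
  assert(plain[0] == 8 && plain[1] == 8);
  // mxfp4 on the LHS: LHS starts with kWidth 4, RHS with kWidth 8.
  auto lhsPacked = pickKWidths(/*isAPacked=*/true, /*isBPacked=*/false, 8, 1);
  assert(lhsPacked[0] == 4 && lhsPacked[1] == 8);
  // mxfp4 on the RHS: the widths flip.
  auto rhsPacked = pickKWidths(/*isAPacked=*/false, /*isBPacked=*/true, 8, 1);
  assert(rhsPacked[0] == 8 && rhsPacked[1] == 4);
  return 0;
}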
@@ -571,11 +577,9 @@ class ScaledBlockedToMFMA final : public OpRewritePattern<triton::DotScaledOp> {
       auto newVType = RankedTensorType::get(
           vType.getShape(), vType.getElementType(), newVEncoding);
       v = rewriter.create<ttg::ConvertLayoutOp>(v.getLoc(), newVType, v);
-      if (type == ScaleDotElemType::BF16)
-        return v;
-      // Don't need to covert int8 holding mxfp4 for A--the upcast_mxfp op can
+      // Don't need to convert int8 holding mxfp4--the upcast_mxfp op can
       // take int8 tensor as input.
-      if (idx == 0 && type == ScaleDotElemType::E2M1)
+      if (type == ScaleDotElemType::BF16 || type == ScaleDotElemType::E2M1)
         return v;
 
       auto vTypeBf16 = RankedTensorType::get(
@@ -586,27 +590,42 @@ class ScaledBlockedToMFMA final : public OpRewritePattern<triton::DotScaledOp> {
     a = toMMABf16(a, 0, aElemType);
     b = toMMABf16(b, 1, bElemType);
 
-    // We need to have "matching" encoding between the A tensor and A scale
+    // We need to have "matching" encoding between the main tensor and scale
     // tensor to make sure the scale values needed is in the same warp. So we
     // adopt the same CTA layout and warps per CTA. The warp dimensions needs to
-    // match along M dimension too. With in a warp, we have 64 threads. We let
-    // each thread read in one scale value. So we need a threadsPerWarp = mDim
-    // along M dimension.
+    // match along M/N dimension too. Within a warp, we have 64 threads. We let
+    // each thread read in one scale value. So we need a threadsPerWarp =
+    // mDim/nDim along M/N dimension. Note that for MFMA intrinsics, mDim is
+    // always the same as nDim. And for scaled dot scale tensor, we always have
+    // K as the innermost dimension. So we have the same threadsPerWarp in the
+    // below no matter A or B scale. Similarly for warpsPerCTA, the non-K
+    // dimension is always at index 0.
+    assert(mDim == nDim);
     SmallVector<unsigned, 2> threadsPerWarp = {mDim, numThreads / mDim};
+    SmallVector<unsigned, 2> blockWarpsPerCTA(rank, 1);
+    blockWarpsPerCTA[0] = numWarps;
     auto newScaleEncoding = triton::gpu::BlockedEncodingAttr::get(
-        ctx, {1, 1}, threadsPerWarp, warpsPerCTA, {1, 0}, ctaLayout);
+        ctx, {1, 1}, threadsPerWarp, blockWarpsPerCTA, {1, 0}, ctaLayout);
+
+    auto upcastMXFP = [&](TensorValue main, TensorValue scale,
+                          ScaleDotElemType elemType) -> Value {
+      if (!scale)
+        return main;
 
-    auto newScaleType = RankedTensorType::get(aScale.getType().getShape(),
-                                              aScale.getType().getElementType(),
-                                              newScaleEncoding);
-    aScale = rewriter.create<ttg::ConvertLayoutOp>(aScale.getLoc(),
-                                                   newScaleType, aScale);
+      auto newScaleType = RankedTensorType::get(
+          scale.getType().getShape(), scale.getType().getElementType(),
+          newScaleEncoding);
+      auto convOp = rewriter.create<ttg::ConvertLayoutOp>(scale.getLoc(),
+                                                          newScaleType, scale);
 
-    auto scaledA = rewriter.create<triton::gpu::UpcastMXFPOp>(
-        dotOp.getLoc(), a, aScale, dotOp.getLhsType());
+      return rewriter.create<triton::gpu::UpcastMXFPOp>(dotOp.getLoc(), main,
+                                                        convOp, elemType);
+    };
 
-    auto newDot =
-        rewriter.create<DotOp>(dotOp.getLoc(), newRetType, scaledA, b, newAcc);
+    Value scaledA = upcastMXFP(a, aScale, dotOp.getLhsType());
+    Value scaledB = upcastMXFP(b, bScale, dotOp.getRhsType());
+    auto newDot = rewriter.create<DotOp>(dotOp.getLoc(), newRetType, scaledA,
+                                         scaledB, newAcc);
     rewriter.replaceOpWithNewOp<ttg::ConvertLayoutOp>(dotOp, oldRetType,
                                                       newDot);
     return success();
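The scale-tensor layout reasoning in the last hunk can likewise be illustrated with a small standalone sketch (not part of the patch; scaleThreadsPerWarp is a hypothetical helper, and the 64-thread warp size follows the comment in the code): each of the 64 threads reads one scale value, so the non-K dimension gets mDim (== nDim) threads and K gets the remaining threads.

#include <cassert>
#include <utility>

// Hypothetical helper mirroring the threadsPerWarp computation above.
static std::pair<unsigned, unsigned>
scaleThreadsPerWarp(unsigned mDim, unsigned numThreads = 64) {
  assert(numThreads % mDim == 0 && "warp size must be divisible by mDim");
  // Non-K dimension first, K dimension second (K is the innermost dimension
  // of the scale tensor).
  return {mDim, numThreads / mDim};
}

int main() {
  // 32x32 MFMA intrinsic: 32 threads along M/N and 2 along K per warp.
  assert(scaleThreadsPerWarp(32) == std::make_pair(32u, 2u));
  // 16x16 MFMA intrinsic: 16 threads along M/N and 4 along K per warp.
  assert(scaleThreadsPerWarp(16) == std::make_pair(16u, 4u));
  return 0;
}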