@@ -149,6 +149,7 @@ warpsPerTileV3(DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,
 static Value
 getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter, int opIdx,
                           bool allowTranspose, bool isMMAv5Fp4Padded = false,
+                          bool forceTranspose = false,
                           Operation *op = nullptr /*only for diagnostic*/) {
   OpBuilder::InsertionGuard g(rewriter);
   Value arg = v;
@@ -167,6 +168,8 @@ getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter, int opIdx,
     } else {
       newOrder = {1, 0};
     }
+    if (forceTranspose)
+      std::swap(newOrder[0], newOrder[1]);
   }
 
   if (newOrder != order && op) {
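
For readers skimming the hunk, here is a standalone sketch of what the new `forceTranspose` flag does. This is illustrative code, not part of the patch: the `opIdx == 1` branch and the surrounding `!allowTranspose` check are assumptions inferred from context, since only the `else` branch and the added swap are visible above.

```cpp
// Illustrative sketch only; not part of the patch.
#include <array>
#include <cstdio>
#include <utility>

// Mirrors the order-selection logic in getSharedMemoryMMAOperand when
// transposition is not allowed: operand B (opIdx == 1) is assumed to get
// {0, 1}, operand A gets {1, 0}; the new forceTranspose flag then flips
// whichever order was chosen.
static std::array<unsigned, 2> pickOrder(int opIdx, bool forceTranspose) {
  std::array<unsigned, 2> newOrder =
      (opIdx == 1) ? std::array<unsigned, 2>{0, 1}
                   : std::array<unsigned, 2>{1, 0};
  if (forceTranspose)
    std::swap(newOrder[0], newOrder[1]);
  return newOrder;
}

int main() {
  // Operand A with forceTranspose set (e.g. when its K dimension is not
  // packed, per the call sites later in this diff):
  auto order = pickOrder(/*opIdx=*/0, /*forceTranspose=*/true);
  std::printf("order = {%u, %u}\n", order[0], order[1]); // prints {0, 1}
  return 0;
}
```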
@@ -648,49 +651,47 @@ class ScaledBlockedToMMAv5
 
     bool IsAMixedPrecFp4 = false;
     bool IsBMixedPrecFp4 = false;
+    bool isAFP4 = dotOp.getAElemType() == ScaleDotElemType::E2M1;
+    bool isBFP4 = dotOp.getBElemType() == ScaleDotElemType::E2M1;
 
     if (dotOp.getAElemType() != dotOp.getBElemType()) {
-      if (dotOp.getAElemType() == ScaleDotElemType::E2M1)
+      if (isAFP4)
         IsAMixedPrecFp4 = true;
-      else if (dotOp.getBElemType() == ScaleDotElemType::E2M1)
+      else if (isBFP4)
         IsBMixedPrecFp4 = true;
     }
-
+    // If we use tcgen05.mma.kind.mxf8f6f4 we need to pad the fp4 operands:
+    // https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-packing-formats-mxf8f6f4-smem
+    bool isMMAv5Fp4PaddedLhs = IsAMixedPrecFp4 || !dotOp.getLhsKPack();
+    bool isMMAv5Fp4PaddedRhs = IsBMixedPrecFp4 || !dotOp.getRhsKPack();
     // For mixed-precision fp4 operands, set allowTranspose = false, to force
     // the packed axis, K, to be contiguous in SMEM
     a = getSharedMemoryMMAOperand(a, rewriter, 0,
-                                  /*allowTranspose=*/!IsAMixedPrecFp4,
-                                  IsAMixedPrecFp4, dotOp);
+                                  /*allowTranspose=*/!isAFP4,
+                                  /*isMMAv5Fp4Padded=*/isMMAv5Fp4PaddedLhs,
+                                  /*forceTranspose=*/!dotOp.getLhsKPack(),
+                                  dotOp);
     b = getSharedMemoryMMAOperand(b, rewriter, 1,
-                                  /*allowTranspose=*/!IsBMixedPrecFp4,
-                                  IsBMixedPrecFp4, dotOp);
+                                  /*allowTranspose=*/!isBFP4,
+                                  /*isMMAv5Fp4Padded=*/isMMAv5Fp4PaddedRhs,
+                                  /*forceTranspose=*/!dotOp.getRhsKPack(),
+                                  dotOp);
 
     MLIRContext *context = dotOp->getContext();
     unsigned m = 128;
     unsigned n = retShapePerCTA[1] >= 256 ? 256 : retShapePerCTA[1];
-    unsigned k = 32;
-    // If both operands are E2M1, target the FP4 tensor core implicitly.
-    // This may result in a downstream compile-time error if the scaled TC
-    // descriptor requires options that are unavailable to the .kind=mxf4 mma.
-    // This is likely preferable over a silent runtime performance degradation
-    // from running f4xf4 via .kind=mxf8f6f4
-    if (dotOp.getAElemType() == ScaleDotElemType::E2M1 &&
-        dotOp.getBElemType() == ScaleDotElemType::E2M1) {
-      k = 64;
-    }
-    SmallVector<unsigned> instrShape = {m, n, k};
+
     ArrayRef<unsigned> CTASplitNum = CTALayout.getCTASplitNum();
     Attribute accEncoding = triton::nvidia_gpu::TensorMemoryEncodingAttr::get(
-        context, instrShape[0], instrShape[1], /*unpacked=*/true,
-        CTASplitNum[0], CTASplitNum[1]);
+        context, m, n, /*unpacked=*/true, CTASplitNum[0], CTASplitNum[1]);
     Attribute tensorMemorySpace =
         triton::nvidia_gpu::TensorMemorySpaceAttr::get(context);
     Type accMemDescType = triton::gpu::MemDescType::get(
         oldRetType.getShape(), oldRetType.getElementType(), accEncoding,
         tensorMemorySpace,
         /*mutableMemory=*/true);
-    Attribute newDistributedEncoding = nvidia_gpu::getTmemCompatibleLayout(
-        instrShape[0], instrShape[1], oldRetType, numWarps);
+    Attribute newDistributedEncoding =
+        nvidia_gpu::getTmemCompatibleLayout(m, n, oldRetType, numWarps);
     auto newAccType = RankedTensorType::get(oldRetType.getShape(),
                                             oldRetType.getElementType(),
                                             newDistributedEncoding);
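
The call sites in this hunk combine three per-operand decisions: `allowTranspose`, `isMMAv5Fp4Padded`, and `forceTranspose`. Here is a self-contained sketch of how they combine, using stand-ins for the real `ScaleDotElemType` / `DotScaledOp` accessors; it is illustrative only, not code from the patch.

```cpp
// Illustrative sketch only; the enum and struct are stand-ins for the real
// ScaleDotElemType / DotScaledOp accessors used in the patch.
#include <cstdio>

enum class ScaleDotElemType { E2M1, E4M3, E5M2 }; // stand-in subset

struct OperandFlags {
  bool allowTranspose;
  bool isMMAv5Fp4Padded;
  bool forceTranspose;
};

// Per the hunk: an operand is padded when it is the fp4 side of a
// mixed-precision dot or when its K dimension is not packed; transposition
// is forced exactly when K is not packed, and never allowed for fp4 inputs.
static OperandFlags computeFlags(ScaleDotElemType self, ScaleDotElemType other,
                                 bool kPack) {
  bool isFp4 = self == ScaleDotElemType::E2M1;
  bool isMixedPrecFp4 = isFp4 && self != other;
  return {/*allowTranspose=*/!isFp4,
          /*isMMAv5Fp4Padded=*/isMixedPrecFp4 || !kPack,
          /*forceTranspose=*/!kPack};
}

int main() {
  // fp4 A operand against an fp8 B operand, with K packed on A:
  OperandFlags fa = computeFlags(ScaleDotElemType::E2M1,
                                 ScaleDotElemType::E4M3, /*kPack=*/true);
  std::printf("padded=%d forceTranspose=%d\n", fa.isMMAv5Fp4Padded,
              fa.forceTranspose); // prints padded=1 forceTranspose=0
  return 0;
}
```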