@@ -264,17 +264,17 @@ static void createScaledGen5MMA(ConversionPatternRewriter &rewriter,
264264 MemDescOperand a, Value b, MemDescOperand d,
265265 Value scaleA, Value scaleB, Value pred,
266266 Value instDescriptor, Value useInitAcc,
267- bool aInTmem, mxfpKind mxfpInstKind) {
267+ bool aInTmem, mxfpKind mxfpInstKind,
268+ bool twoCTAs) {
268269 PTXBuilder ptxBuilder;
269- std::string opcode;
270+ std::string opcode =
271+ " tcgen05.mma.cta_group::" + std::to_string (twoCTAs ? 2 : 1 ) + " .kind::" ;
270272 if (mxfpInstKind == mxfpKind::mxf8f6f4) {
271- opcode =
272- " tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale.scale_vec::1X" ;
273+ opcode += " mxf8f6f4.block_scale.scale_vec::1X" ;
273274 } else if (mxfpInstKind == mxfpKind::mxf4) {
274- opcode = " tcgen05.mma.cta_group::1.kind:: mxf4.block_scale.scale_vec::2X" ;
275+ opcode + = " mxf4.block_scale.scale_vec::2X" ;
275276 } else if (mxfpInstKind == mxfpKind::mxf4nvf4) {
276- opcode =
277- " tcgen05.mma.cta_group::1.kind::mxf4nvf4.block_scale.scale_vec::4X" ;
277+ opcode += " mxf4nvf4.block_scale.scale_vec::4X" ;
278278 } else {
279279 assert (0 && " Unsupported mxfp kind." );
280280 }
@@ -312,7 +312,9 @@ static void createMMACommit(ConversionPatternRewriter &rewriter, Location loc,
312312 " tcgen05.commit.cta_group::2.mbarrier::arrive::one.shared::"
313313 " cluster.multicast::cluster.b64 [$1], $2;" ;
314314 } else {
315- opcode = " @$0 tcgen05.commit.cta_group::1.mbarrier::arrive::one.b64 [$1];" ;
315+ opcode =
316+ " @$0 tcgen05.commit.cta_group::" + std::to_string (twoCTAs ? 2 : 1 ) +
317+ " .mbarrier::arrive::one.b64 [$1];" ;
316318 }
317319 auto &barrierOp = *ptxBuilder.create (opcode);
318320 barrierOp (ptxOperands, /* onlyAttachMLIRArgs=*/ true );
@@ -486,7 +488,8 @@ void convertDot(const LLVMTypeConverter &typeConverter,
486488 MemDescType bTensorTy = op.getB ().getType ();
487489 MemDescType dTensorTy = op.getD ().getType ();
488490 auto dLayout = cast<ttng::TensorMemoryEncodingAttr>(dTensorTy.getEncoding ());
489- bool twoCTAs = op.getTwoCtas ();
491+ bool twoCTAs = ttng::getModuleTwoCTAs (op);
492+ assert (twoCTAs == op.getTwoCtas ());
490493
491494 DotConversion dot;
492495
@@ -595,6 +598,7 @@ void convertScaledDot(const LLVMTypeConverter &typeConverter,
595598 Value baseD = tb.ptrtoint (i32_ty, adaptor.getD ());
596599 Value baseScaleA = tb.ptrtoint (i32_ty, adaptor.getAScale ());
597600 Value baseScaleB = tb.ptrtoint (i32_ty, adaptor.getBScale ());
601+ bool twoCTAs = ttng::getModuleTwoCTAs (op);
598602
599603 int numRows = 128 ;
600604 int colSizeInBits = 32 ;
@@ -634,14 +638,13 @@ void convertScaledDot(const LLVMTypeConverter &typeConverter,
634638 subWordIdx, subWordIdx, mxfpInstKind);
635639 createScaledGen5MMA (rewriter, loc, op, a, b, accAddress, scaleA, scaleB,
636640 pred, instDescriptor, useInitAcc, desc.aInTmem ,
637- mxfpInstKind);
641+ mxfpInstKind, twoCTAs );
638642 };
639643
640644 convertDotImpl (typeConverter, rewriter, loc, op.getA (), op.getB (),
641645 adaptor.getA (), adaptor.getB (), dTensorTy, adaptor.getUseD (),
642646 adaptor.getPred (), adaptor.getBarriers (),
643- adaptor.getBarrierPreds (), /* twoCTAs=*/ false , opKindIsMXFP4,
644- dot);
647+ adaptor.getBarrierPreds (), twoCTAs, opKindIsMXFP4, dot);
645648}
646649
647650// ===----------------------------------------------------------------------===//
@@ -699,7 +702,7 @@ struct TCGen5CommitOpConversion
699702 pred = b.and_ (adaptor.getPred (), pred);
700703
701704 createMMACommit (rewriter, op.getLoc (), smemObj.getBase (), pred,
702- op. getTwoCtas ( ));
705+ ttng::getModuleTwoCTAs (op ));
703706 rewriter.eraseOp (op);
704707 return success ();
705708 }
0 commit comments