@@ -98,6 +98,25 @@ Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter,
   return mask;
 }
 
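+// Register-constraint letter for a PTX inline-asm operand of the given bit
+// width (e.g. 16 -> "h", 32 -> "r" or "f", 64 -> "l" or "d").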
+std::string getRegisterSizeCode(int size, bool is_float) {
+  switch (size) {
+  case 1:
+    return "b";
+  case 16:
+    return "h";
+  case 32:
+    return is_float ? "f" : "r";
+  case 64:
+    return is_float ? "d" : "l";
+  case 128:
+    return "q";
+  default:
+    llvm_unreachable("Unsupported register size");
+  }
+}
+
 // Contains some helper functions for both Load and Store conversions.
 struct LoadStoreConversionBase {
   explicit LoadStoreConversionBase(const NVIDIA::TargetInfo &targetInfo,
@@ -632,6 +651,22 @@ struct AtomicRMWOpConversion
       : ConvertOpToLLVMPattern<triton::AtomicRMWOp>(converter, benefit),
         LoadStoreConversionBase(targetInfo, axisAnalysisPass) {}
 
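+  // "Vectorized" means the .v2/.v4 forms of atom; packed types like f16x2
+  // are handled separately below.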
+  bool supportsVectorized(Operation *moduleOp, RMWOp opType,
+                          Type elementType) const {
+    // vectorized atomics are only supported on hopper,
+    // and only for specific atomic ops (add, min, max).
+    // Note that "packed types" like f16x2 are supported sm60+.
+    auto computeCapability = getNVIDIAComputeCapability(moduleOp);
+    if (computeCapability < 90) {
+      return false;
+    }
+
+    return opType == RMWOp::FADD &&
+           (elementType.isF16() || elementType.isBF16() || elementType.isF32());
+  }
+
   LogicalResult
   matchAndRewrite(triton::AtomicRMWOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -664,45 +699,88 @@ struct AtomicRMWOpConversion
                  : valueTy;
     const size_t valueElemNBits = valueElemTy.getIntOrFloatBitWidth();
     auto elemsPerThread = getTotalElemsPerThread(val.getType());
-    // vec = 1, numElements = 1 for scalar
-    auto vec = getVectorSize(ptr);
-    auto vecOrig = vec;
-    int numElems = 1;
-    // tensor
+    // packed: e.g. packed=2 for f16x2
+    // vec: e.g. .v2, .v4, .v8 version of atom instruction.
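+    // At most one of vec and packed may exceed 1 (asserted below).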
+    unsigned vec, vecOrig;
+    int numElems, packed;
     if (tensorTy) {
+      vec = getVectorSize(ptr);
+      if (llMask) {
+        vec = std::min<unsigned>(vec, getMaskAlignment(op.getMask()));
+      }
+      vecOrig = vec;
+      packed = 1;
       auto valTy = cast<RankedTensorType>(val.getType());
-      vec = std::min<unsigned>(vec, valTy.getElementType().isF16() ? 2 : 1);
-      // mask
+      if (!supportsVectorized(moduleOp, atomicRmwAttr,
+                              valTy.getElementType())) {
+        packed =
+            std::min<unsigned>(vecOrig, valTy.getElementType().isF16() ? 2 : 1);
+        vec = 1;
+      }
       numElems = tensorTy.getNumElements();
+    } else {
+      // scalar
+      vec = 1;
+      vecOrig = 1;
+      numElems = 1;
+      packed = 1;
     }
+    assert((packed == 1 || vec == 1) && "packed or vec must be 1");
 
-    if (vec == 1 && numElems > 1)
+    if (vec * packed == 1 && numElems > 1)
       op->emitRemark() << "Warning: vectorization fails vec = " << vec
-                       << " origin vec = " << vecOrig
+                       << " packed = " << packed << " origin vec = " << vecOrig
                        << " numElems = " << numElems;
 
     Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo);
 
-    auto vecTy = vec_ty(valueElemTy, vec);
+    auto packedTy = vec_ty(valueElemTy, packed);
     SmallVector<Value> resultVals(elemsPerThread);
-    for (size_t i = 0; i < elemsPerThread; i += vec) {
-      Value rmwVal = undef(vecTy);
-      for (int ii = 0; ii < vec; ++ii) {
-        Value iiVal = createIndexAttrConstant(
-            rewriter, loc, getTypeConverter()->getIndexType(), ii);
-        rmwVal = insert_element(vecTy, rmwVal, valElements[i + ii], iiVal);
-      }
-
+    for (size_t i = 0; i < elemsPerThread; i += vec * packed) {
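+      // Each iteration emits one atom instruction covering vec * packed
+      // elements.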
       Value rmwPtr = ptrElements[i];
       Value rmwMask = llMask ? and_(mask, maskElements[i]) : mask;
       std::string sTy;
       PTXBuilder ptxBuilderAtomicRMW;
-      std::string tyId = valueElemNBits * vec == 64
-                             ? "l"
-                             : (valueElemNBits * vec == 32 ? "r" : "h");
-      auto *dstOpr = ptxBuilderAtomicRMW.newOperand("=" + tyId, /*init=*/true);
+      // 16-bit -> "h", 32-bit -> "r", 64-bit -> "l"
+      std::string tyId =
+          getRegisterSizeCode(valueElemNBits * packed, /*is_float=*/false);
+
+      PTXBuilder::Operand *dstOpr;
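+      // A vectorized (.vN) atom writes one register per element, so the
+      // destination becomes a list of vec "=" operands.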
+      if (vec > 1) {
+        dstOpr = ptxBuilderAtomicRMW.newListOperand();
+        for (unsigned ii = 0; ii < vec; ++ii) {
+          dstOpr->listAppend(
+              ptxBuilderAtomicRMW.newOperand("=" + tyId, /*init=*/true));
+        }
+      } else {
+        dstOpr = ptxBuilderAtomicRMW.newOperand("=" + tyId, /*init=*/true);
+      }
+
       auto *ptrOpr = ptxBuilderAtomicRMW.newAddrOperand(rmwPtr, "l");
-      auto *valOpr = ptxBuilderAtomicRMW.newOperand(rmwVal, tyId);
+
+      PTXBuilder::Operand *valOpr;
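+      // Source operand: a register list for .vN, a packed f16x2 value in one
+      // 32-bit register, or a plain scalar.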
+      if (vec > 1) {
+        valOpr = ptxBuilderAtomicRMW.newListOperand();
+        for (unsigned ii = 0; ii < vec; ++ii) {
+          valOpr->listAppend(
+              ptxBuilderAtomicRMW.newOperand(valElements[i + ii], tyId));
+        }
+      } else if (packed > 1) {
+        Value rmwVal = undef(packedTy);
+        for (int ii = 0; ii < packed; ++ii) {
+          rmwVal = insert_element(packedTy, rmwVal, valElements[i + ii],
+                                  i32_val(ii));
+        }
+        valOpr = ptxBuilderAtomicRMW.newOperand(rmwVal, tyId);
+      } else {
+        valOpr = ptxBuilderAtomicRMW.newOperand(valElements[i], tyId);
+      }
 
       auto scope = stringifyMemSyncScope(op.getScope()).str();
       auto &atom = ptxBuilderAtomicRMW.create<>("atom")->global().o(scope);
@@ -725,7 +803,8 @@ struct AtomicRMWOpConversion
         rmwOp = "add";
         rmwOp += (valueElemNBits == 16 ? ".noftz" : "");
         sTy = "f" + sBits;
-        sTy += (vec == 2 && valueElemNBits == 16) ? "x2" : "";
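+        // Packed f16x2: two f16 values share one 32-bit register.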
+        sTy += (packed == 2 && valueElemNBits == 16) ? "x2" : "";
         break;
       case RMWOp::MAX:
         sTy = "s" + sBits;
@@ -750,15 +829,36 @@ struct AtomicRMWOpConversion
       std::string semStr;
       llvm::raw_string_ostream os(semStr);
       os << op.getSem();
-      atom.o(semStr).o(rmwOp).o(sTy);
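+      // v(vec) adds the .v2/.v4 vector suffix when vec > 1.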
+      atom.o(semStr).o(rmwOp).v(vec).o(sTy);
       if (tensorTy) {
         atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
-        auto retType = vec == 1 ? valueElemTy : vecTy;
+        Type retType;
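+        // .vN results come back as one value per register (an LLVM struct);
+        // packed results come back as a single <2 x f16>-style vector.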
+        if (vec > 1) {
+          SmallVector<Type> retTys(vec, valueElemTy);
+          retType = struct_ty(retTys);
+        } else if (packed > 1) {
+          retType = packedTy;
+        } else {
+          retType = valueElemTy;
+        }
+
         auto ret = ptxBuilderAtomicRMW.launch(rewriter, loc, retType);
-        for (int ii = 0; ii < vec; ++ii) {
-          resultVals[i + ii] =
-              vec == 1 ? ret : extract_element(valueElemTy, ret, i32_val(ii));
+
+        if (vec > 1) {
+          for (unsigned ii = 0; ii < vec; ++ii) {
+            resultVals[i + ii] = extract_val(valueElemTy, ret, ii);
+          }
+        } else if (packed > 1) {
+          for (unsigned ii = 0; ii < packed; ++ii) {
+            resultVals[i + ii] = extract_element(valueElemTy, ret, i32_val(ii));
+          }
+        } else {
+          resultVals[i] = ret;
         }
+
       } else {
         auto ASMReturnTy = void_ty(ctx);
         atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);