@@ -694,6 +694,32 @@ struct AtomicCASOpConversion
694694 }
695695};
696696
697+ bool supportsGlobalAtomicF16PackedAndDpp (triton::AMD::ISAFamily isaFamily) {
698+ return isaFamily == triton::AMD::ISAFamily::CDNA1 ||
699+ isaFamily == triton::AMD::ISAFamily::CDNA2 ||
700+ isaFamily == triton::AMD::ISAFamily::CDNA3;
701+ }
702+
703+ Value generateI32DppMove (PatternRewriter &rewriter, Value val, int dppCtrl) {
704+ assert (val.getType ().isInteger (32 ));
705+ auto loc = val.getLoc ();
706+ Value old = i32_val (0 );
707+ int rowMask = 0b1111 ; // enable all rows
708+ int bankMask = 0b1111 ; // enable all banks
709+ bool boundCtrl = false ;
710+ auto dppMovOp = rewriter.create <ROCDL::DPPUpdateOp>(
711+ loc, i32_ty, old, val, dppCtrl, rowMask, bankMask, boundCtrl);
712+ return dppMovOp.getResult ();
713+ }
714+
715+ Value shiftLeftI32ByDpp (PatternRewriter &rewriter, Value val) {
716+ return generateI32DppMove (rewriter, val, 0x101 ); // shift left 1 lane
717+ }
718+
719+ Value shiftRightI32ByDpp (PatternRewriter &rewriter, Value val) {
720+ return generateI32DppMove (rewriter, val, 0x111 ); // shift right 1 lane
721+ }
722+
697723struct AtomicRMWOpConversion
698724 : public ConvertOpToLLVMPattern<triton::AtomicRMWOp>,
699725 public LoadStoreConversionBase {
@@ -765,27 +791,52 @@ struct AtomicRMWOpConversion
765791 // vec = 1, numElements = 1 for scalar
766792 auto vec = getVectorSize (ptr);
767793 int numElems = 1 ;
794+ Type packF16Ty = vec_ty (valueElemTy, 2 );
795+
796+ // In the case of unpaired f16 elements utilize dpp instructions to
797+ // accelerate atomics. Here is an algorithm of lowering
798+ // tt::atomicRmwOp(%ptr, %val, %mask):
799+ // 0. Group threads into pairs. Master thread is (tid % 2 == 0);
800+ // 1. All the threads send %val to (tid - 1) thread via dppUpdateOp shl, so
801+ // all the masters receive value from secondary threads;
802+ // 2. Take into account parity in the %mask value, build control flow
803+ // structures according to it;
804+ // 3. Generate llvm::atomicRmwOp in the threads enabled by %mask value;
805+ // 4. All the threads send result of generated operation to (tid + 1) thread
806+ // via dppUpdateOp shr, so all secondary threads also receive their
807+ // result.
808+ //
809+ // This approach enables us to use half the active threads committing atomic
810+ // requests, avoids generating code providing unified access to the f16
811+ // element, and reduces contention.
812+ bool useDppForPackedF16 = false ;
768813 // tensor
769814 if (tensorTy) {
770815 auto valTy = cast<RankedTensorType>(val.getType ());
771- Type elTy = valTy.getElementType ();
772- vec = std::min<unsigned >(vec, llvm::isa<FloatType>(elTy) &&
773- elTy.getIntOrFloatBitWidth () == 16
774- ? 2
775- : 1 );
816+ bool isF16Ty = valueElemTy.isF16 () || valueElemTy.isBF16 ();
817+ unsigned availableVecSize = isF16Ty ? 2 : 1 ;
818+ vec = std::min<unsigned >(vec, availableVecSize);
819+ // Force F16 packing in the case it's not coming in as packed, but the
820+ // ISA can support packed atomic instructions.
821+ useDppForPackedF16 =
822+ supportsGlobalAtomicF16PackedAndDpp (targetInfo.getISAFamily ()) &&
823+ vec == 1 && isF16Ty && atomicRmwAttr == RMWOp::FADD;
776824 // mask
777825 numElems = tensorTy.getNumElements ();
778826 }
779827 Value mask = int_val (1 , 1 );
780828 auto tid = tid_val ();
781829 mask = and_ (mask,
782830 icmp_slt (mul (tid, i32_val (elemsPerThread)), i32_val (numElems)));
831+ if (useDppForPackedF16)
832+ mask = and_ (mask, icmp_eq (urem (tid, i32_val (2 )), i32_val (0 )));
783833
784834 auto memOrdering = op.getSem ();
785835 auto atomicMemOrdering = getMemoryOrdering (memOrdering);
786836
787837 auto vecTy = vec_ty (valueElemTy, vec);
788838 auto retType = vec == 1 ? valueElemTy : vecTy;
839+ retType = useDppForPackedF16 ? packF16Ty : retType;
789840 SmallVector<Value> resultVals (elemsPerThread);
790841 for (size_t i = 0 ; i < elemsPerThread; i += vec) {
791842 Value rmwPtr = ptrElements[i];
@@ -794,7 +845,24 @@ struct AtomicRMWOpConversion
794845 Value rmwMask = llMask ? and_ (mask, maskElements[i]) : mask;
795846
796847 Value operand;
797- if (vec == 1 ) {
848+ if (useDppForPackedF16) {
849+ // Move %val to left neighbour to proceed packed atomic further.
850+ Value packedVal = null (packF16Ty);
851+ packedVal =
852+ insert_element (packF16Ty, packedVal, valElements[i], i32_val (0 ));
853+ // Pack to i32 type to simplify transaction
854+ packedVal = bitcast (packedVal, i32_ty);
855+ Value dppMoveRes = shiftLeftI32ByDpp (rewriter, packedVal);
856+ // Unpack results back
857+ Value unpackedDppRes = bitcast (dppMoveRes, packF16Ty);
858+ operand = undef (packF16Ty);
859+ operand =
860+ insert_element (packF16Ty, operand, valElements[i], i32_val (0 ));
861+ operand = insert_element (
862+ packF16Ty, operand,
863+ extract_element (valueElemTy, unpackedDppRes, i32_val (0 )),
864+ i32_val (1 ));
865+ } else if (vec == 1 ) {
798866 operand = valElements[i];
799867 } else {
800868 operand = undef (vecTy);
@@ -836,10 +904,25 @@ struct AtomicRMWOpConversion
836904 rewriter.setInsertionPointToStart (endBlock);
837905 Value retVal = endBlock->getArgument (0 );
838906 if (tensorTy) {
839- for (int ii = 0 ; ii < vec; ++ii) {
840- resultVals[i + ii] =
841- vec == 1 ? retVal
842- : extract_element (valueElemTy, retVal, i32_val (ii));
907+ if (useDppForPackedF16) {
908+ // Return packed to i32 result after atomic operation back from master
909+ // lane.
910+ auto packedRet = bitcast (retVal, i32_ty);
911+ Value dppMovRes = shiftRightI32ByDpp (rewriter, packedRet);
912+ // Unpack results back
913+ Value unpackedDppRes = bitcast (dppMovRes, packF16Ty);
914+ retVal = insert_element (
915+ packF16Ty, retVal,
916+ extract_element (valueElemTy, unpackedDppRes, i32_val (1 )),
917+ i32_val (1 ));
918+ resultVals[i] =
919+ extract_element (valueElemTy, retVal, urem (tid, i32_val (2 )));
920+ } else {
921+ for (int ii = 0 ; ii < vec; ++ii) {
922+ resultVals[i + ii] =
923+ vec == 1 ? retVal
924+ : extract_element (valueElemTy, retVal, i32_val (ii));
925+ }
843926 }
844927 } else {
845928 if (!atomicNeedsSharedMemory (op.getResult ())) {
0 commit comments