@@ -714,6 +714,32 @@ struct AtomicCASOpConversion
714714 }
715715};
716716
717+ bool supportsGlobalAtomicF16PackedAndDpp (triton::AMD::ISAFamily isaFamily) {
718+ return isaFamily == triton::AMD::ISAFamily::CDNA1 ||
719+ isaFamily == triton::AMD::ISAFamily::CDNA2 ||
720+ isaFamily == triton::AMD::ISAFamily::CDNA3;
721+ }
722+
723+ Value generateI32DppMove (PatternRewriter &rewriter, Value val, int dppCtrl) {
724+ assert (val.getType ().isInteger (32 ));
725+ auto loc = val.getLoc ();
726+ Value old = i32_val (0 );
727+ int rowMask = 0b1111 ; // enable all rows
728+ int bankMask = 0b1111 ; // enable all banks
729+ bool boundCtrl = false ;
730+ auto dppMovOp = rewriter.create <ROCDL::DPPUpdateOp>(
731+ loc, i32_ty, old, val, dppCtrl, rowMask, bankMask, boundCtrl);
732+ return dppMovOp.getResult ();
733+ }
734+
735+ Value shiftLeftI32ByDpp (PatternRewriter &rewriter, Value val) {
736+ return generateI32DppMove (rewriter, val, 0x101 ); // shift left 1 lane
737+ }
738+
739+ Value shiftRightI32ByDpp (PatternRewriter &rewriter, Value val) {
740+ return generateI32DppMove (rewriter, val, 0x111 ); // shift right 1 lane
741+ }
742+
717743struct AtomicRMWOpConversion
718744 : public ConvertOpToLLVMPattern<triton::AtomicRMWOp>,
719745 public LoadStoreConversionBase {
@@ -785,27 +811,52 @@ struct AtomicRMWOpConversion
785811 // vec = 1, numElements = 1 for scalar
786812 auto vec = getVectorSize (ptr);
787813 int numElems = 1 ;
814+ Type packF16Ty = vec_ty (valueElemTy, 2 );
815+
816+ // In the case of unpaired f16 elements utilize dpp instructions to
817+ // accelerate atomics. Here is an algorithm of lowering
818+ // tt::atomicRmwOp(%ptr, %val, %mask):
819+ // 0. Group thread by pairs. Master thread is (tid % 2 == 0);
820+ // 1. All the threads send %val to (tid - 1) thread via dppUpdateOp shl, so
821+ // all the masters receive value from secondary threads;
822+ // 2. Take into account parity in the %mask value, build control flow
823+ // structures according to it;
824+ // 3. Generate llvm::atomicRmwOp in the threads enabled by %mask value;
825+ // 4. All the threads send result of generated operation to (tid + 1) thread
826+ // via dppUpdateOp shr, so all the secondary threads also receive their
827+ // result.
828+ //
829+ // This approach enables us to use half the active threads committing atomic
830+ // requests, to avoid generating code providing unified access to the f16
831+ // element, and to reduce contention.
832+ bool useDppForPackedF16 = false ;
788833 // tensor
789834 if (tensorTy) {
790835 auto valTy = cast<RankedTensorType>(val.getType ());
791- Type elTy = valTy.getElementType ();
792- vec = std::min<unsigned >(vec, llvm::isa<FloatType>(elTy) &&
793- elTy.getIntOrFloatBitWidth () == 16
794- ? 2
795- : 1 );
836+ bool isF16Ty = valueElemTy.isF16 () || valueElemTy.isBF16 ();
837+ unsigned availableVecSize = isF16Ty ? 2 : 1 ;
838+ vec = std::min<unsigned >(vec, availableVecSize);
839+ // Force F16 packing in the case it's not coming in as packed, but the
840+ // ISA can support packed atomic instructions.
841+ useDppForPackedF16 =
842+ supportsGlobalAtomicF16PackedAndDpp (targetInfo.getISAFamily ()) &&
843+ vec == 1 && isF16Ty && atomicRmwAttr == RMWOp::FADD;
796844 // mask
797845 numElems = tensorTy.getNumElements ();
798846 }
799847 Value mask = int_val (1 , 1 );
800848 auto tid = tid_val ();
801849 mask = and_ (mask,
802850 icmp_slt (mul (tid, i32_val (elemsPerThread)), i32_val (numElems)));
851+ if (useDppForPackedF16)
852+ mask = and_ (mask, icmp_eq (urem (tid, i32_val (2 )), i32_val (0 )));
803853
804854 auto memOrdering = op.getSem ();
805855 auto atomicMemOrdering = getMemoryOrdering (memOrdering);
806856
807857 auto vecTy = vec_ty (valueElemTy, vec);
808858 auto retType = vec == 1 ? valueElemTy : vecTy;
859+ retType = useDppForPackedF16 ? packF16Ty : retType;
809860 SmallVector<Value> resultVals (elemsPerThread);
810861 for (size_t i = 0 ; i < elemsPerThread; i += vec) {
811862 Value rmwPtr = ptrElements[i];
@@ -814,7 +865,24 @@ struct AtomicRMWOpConversion
814865 Value rmwMask = llMask ? and_ (mask, maskElements[i]) : mask;
815866
816867 Value operand;
817- if (vec == 1 ) {
868+ if (useDppForPackedF16) {
869+ // Move %val to left neighbour to proceed packed atomic further.
870+ Value packedVal = null (packF16Ty);
871+ packedVal =
872+ insert_element (packF16Ty, packedVal, valElements[i], i32_val (0 ));
873+ // Pack to i32 type to simplify transaction
874+ packedVal = bitcast (packedVal, i32_ty);
875+ Value dppMoveRes = shiftLeftI32ByDpp (rewriter, packedVal);
876+ // Unpack results back
877+ Value unpackedDppRes = bitcast (dppMoveRes, packF16Ty);
878+ operand = undef (packF16Ty);
879+ operand =
880+ insert_element (packF16Ty, operand, valElements[i], i32_val (0 ));
881+ operand = insert_element (
882+ packF16Ty, operand,
883+ extract_element (valueElemTy, unpackedDppRes, i32_val (0 )),
884+ i32_val (1 ));
885+ } else if (vec == 1 ) {
818886 operand = valElements[i];
819887 } else {
820888 operand = undef (vecTy);
@@ -856,10 +924,25 @@ struct AtomicRMWOpConversion
856924 rewriter.setInsertionPointToStart (endBlock);
857925 Value retVal = endBlock->getArgument (0 );
858926 if (tensorTy) {
859- for (int ii = 0 ; ii < vec; ++ii) {
860- resultVals[i + ii] =
861- vec == 1 ? retVal
862- : extract_element (valueElemTy, retVal, i32_val (ii));
927+ if (useDppForPackedF16) {
928+ // Return packed to i32 result after atomic operation back from master
929+ // lane.
930+ auto packedRet = bitcast (retVal, i32_ty);
931+ Value dppMovRes = shiftRightI32ByDpp (rewriter, packedRet);
932+ // Unpack results back
933+ Value unpackedDppRes = bitcast (dppMovRes, packF16Ty);
934+ retVal = insert_element (
935+ packF16Ty, retVal,
936+ extract_element (valueElemTy, unpackedDppRes, i32_val (1 )),
937+ i32_val (1 ));
938+ resultVals[i] =
939+ extract_element (valueElemTy, retVal, urem (tid, i32_val (2 )));
940+ } else {
941+ for (int ii = 0 ; ii < vec; ++ii) {
942+ resultVals[i + ii] =
943+ vec == 1 ? retVal
944+ : extract_element (valueElemTy, retVal, i32_val (ii));
945+ }
863946 }
864947 } else {
865948 if (!atomicNeedsSharedMemory (op.getResult ())) {
0 commit comments