@@ -176,8 +176,7 @@ struct LoadStoreConversionBase {
     return axisAnalysisPass.getMaskAlignment(mask);
   }
 
-  std::optional<const std::string>
-  getAMDGPUMemScopeStr(MemSyncScope scope) const {
+  std::optional<const char *> getAMDGPUMemScopeStr(MemSyncScope scope) const {
     // See: https://llvm.org/docs/AMDGPUUsage.html#memory-scopes
     auto scopeStr = "";
     switch (scope) {
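For context, the returned strings are LLVM's AMDGPU syncscope names ("agent", "workgroup", and the empty string for system scope); since these are string literals with static storage, handing out a const char * avoids wrapping each one in a std::string. A minimal standalone sketch of such a mapping follows; the ScopeKind enum and getScopeStr name are illustrative stand-ins, not the actual MemSyncScope enum used in this file.

#include <optional>

// Illustrative stand-in for the real MemSyncScope enumeration.
enum class ScopeKind { System, Gpu, Cta, Unknown };

// Maps a scope to an AMDGPU syncscope string, per
// https://llvm.org/docs/AMDGPUUsage.html#memory-scopes.
// String literals have static storage duration, so returning a
// const char * is safe and avoids constructing std::string copies.
std::optional<const char *> getScopeStr(ScopeKind scope) {
  switch (scope) {
  case ScopeKind::System:
    return ""; // the empty syncscope denotes system scope
  case ScopeKind::Gpu:
    return "agent";
  case ScopeKind::Cta:
    return "workgroup";
  default:
    return std::nullopt; // unrecognized scope
  }
}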
@@ -1295,44 +1294,43 @@ struct AtomicRMWOpConversion
     // tt::atomicRmwOp(%ptr, %val, %mask):
     // 0. Group thread by pairs. Master thread is (tid % 2 == 0);
     // 1. All the threads send %val to (tid - 1) thread via dppUpdateOp shl, so
-    //    all the masters recieve value from secondary threads;
+    //    all the masters receive value from secondary threads;
     // 2. Take into account parity in the %mask value, build control flow
     //    structures according to it;
     // 3. Generate llvm::atomicRmwOp in the threads enabled by %mask value;
     // 4. All the threads send result of generated operation to (tid + 1) thread
-    //    via dppUpdateOp shl, so all secondary thread also recieve their
+    //    via dppUpdateOp shl, so all secondary thread also receive their
     //    result.
     //
     // This approach enables us to use half the active threads committing atomic
     // requests to avoid generating of code providing unified access to f16
-    // element and reduce contantion.
+    // element and reduce contention.
     bool useDppForPackedF16 = false;
     // tensor
     if (tensorTy) {
       auto valTy = cast<RankedTensorType>(val.getType());
       bool isF16Ty = valueElemTy.isF16() || valueElemTy.isBF16();
       unsigned availableVecSize = isF16Ty ? 2 : 1;
       vec = std::min<unsigned>(vec, availableVecSize);
-      // Force F16 packing in the case it's not comming in as packed, but the
+      // Force F16 packing in the case it's not coming in as packed, but the
       // ISA can support packed atomic instructions.
       useDppForPackedF16 =
           supportsGlobalAtomicF16PackedAndDpp(targetInfo.getISAFamily()) &&
-          vec == 1 && isF16Ty && atomicRmwAttr == RMWOp::FADD;
+          vec == 1 && isF16Ty && atomicRmwAttr == RMWOp::FADD &&
+          !enableIntraWaveReduce;
       // mask
       numElems = tensorTy.getNumElements();
     }
-    Value mask = b.int_val(1, 1);
+    Value mask = b.true_val();
     auto tid = getThreadId(rewriter, loc);
     mask = b.and_(mask, b.icmp_slt(b.mul(tid, b.i32_val(elemsPerThread)),
                                    b.i32_val(numElems)));
-    if (useDppForPackedF16)
-      mask = b.and_(mask, b.icmp_eq(b.urem(tid, b.i32_val(2)), b.i32_val(0)));
 
     auto memOrdering = op.getSem();
     auto scope = op.getScope();
     auto atomicMemOrdering = getMemoryOrdering(memOrdering);
 
-    auto scopeStr = getAMDGPUMemScopeStr(scope);
+    std::optional<const char *> scopeStr = getAMDGPUMemScopeStr(scope);
     if (!scopeStr)
       return rewriter.notifyMatchFailure(op, "Unknown AMDGPU memory scope");
 
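To make the pairing scheme described in the comment above concrete, here is a small host-side model (plain C++, no DPP instructions or atomics) of what the even/odd pairing buys: each even "master" lane commits one packed update covering itself and its right neighbour, so only half of the lanes issue atomic requests. The 8-lane wavefront size, array types, and function name are illustrative assumptions, not code from this lowering.

#include <cstdio>

// Host-side model of the even/odd pairing: lane values live at consecutive
// addresses, and each even ("master") lane applies the update for both
// itself and lane + 1, halving the number of atomic requests.
constexpr int kLanes = 8; // illustrative wavefront size

void packedPairUpdate(const float (&val)[kLanes], float (&memory)[kLanes]) {
  for (int lane = 0; lane < kLanes; lane += 2) {
    // Step 1: on hardware a DPP shift moves the odd lane's value to its even
    // neighbour; in this host model we simply read val[lane + 1].
    float fromRightNeighbour = val[lane + 1];
    // Step 3: only the even lane performs the (conceptually packed) add.
    memory[lane] += val[lane];
    memory[lane + 1] += fromRightNeighbour;
    // Step 4 would shift results back so odd lanes also see their result.
  }
}

int main() {
  float val[kLanes] = {1, 2, 3, 4, 5, 6, 7, 8};
  float memory[kLanes] = {};
  packedPairUpdate(val, memory);
  for (float m : memory)
    std::printf("%g ", m); // prints: 1 2 3 4 5 6 7 8
  std::printf("\n");
  return 0;
}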
@@ -1346,24 +1344,59 @@ struct AtomicRMWOpConversion
       // elemsPerThread.
       Value rmwMask = llMask ? b.and_(mask, maskElements[i]) : mask;
 
+      Value i64Ones = b.i64_val(~uint64_t(0));
+      Value i64Zeros = b.i64_val(0);
       Value operand;
+      Value rightNeighbourPtr;
+      Value enablePackedOpt;
       if (useDppForPackedF16) {
+        Value isOddI32 = b.urem(tid, b.i32_val(2));
+        // First check if odd threads hold adjacent ptrs to even ones.
+        Value castedAddr = b.ptrtoint(i64_ty, rmwPtr);
+        // Set casted addr to all ones if the thread is disabled.
+        castedAddr = b.select(rmwMask, castedAddr, i64Ones);
+
         // Move %val to left neighbour to proceed packed atomic further.
         Value packedVal = b.null(packF16Ty);
-        packedVal = b.insert_element(packF16Ty, packedVal, valElements[i],
-                                     b.i32_val(0));
-        // Pack to i32 type to simplify transaction
+        packedVal =
+            b.insert_element(packF16Ty, packedVal, valElements[i], isOddI32);
+        // Pack to i32 type to simplify transaction.
         packedVal = b.bitcast(packedVal, i32_ty);
+        // Zero operands for disabled threads to make addition no op.
+        packedVal = b.select(rmwMask, packedVal, b.i32_val(0));
         Value dppMoveRes = shiftLeftI32ByDpp(rewriter, packedVal);
+
+        Value rightNeighbourAddr =
+            genI32TiledOp(rewriter, shiftLeftI32ByDpp, castedAddr);
+
+        // Packing optimization only supported if following conditions are true:
+        // 1. address is aligned by 4 bytes
+        // 2. right neighbour has adjacent address
+        // 3. both threads are active
+        Value isAligned =
+            b.icmp_eq(b.urem(castedAddr, b.i64_val(4)), b.i64_val(0));
+        Value neighbourAddrAdjacent = b.icmp_eq(
+            rightNeighbourAddr,
+            b.add(castedAddr,
+                  b.i64_val(valueElemTy.getIntOrFloatBitWidth() / 8)));
+        Value neighbourEnabled = b.icmp_ne(i64Ones, rightNeighbourAddr);
+        Value bothEnabled = b.and_(neighbourEnabled, rmwMask);
+        enablePackedOpt =
+            b.and_(b.and_(isAligned, bothEnabled), neighbourAddrAdjacent);
+
+        // Enable only the even threads.
+        Value anyEnabled = b.or_(neighbourEnabled, rmwMask);
+        // If one of the threads is disabled, use the neighbour's addr.
+        rightNeighbourAddr =
+            b.select(neighbourEnabled, rightNeighbourAddr, castedAddr);
+        castedAddr = b.select(rmwMask, castedAddr, rightNeighbourAddr);
+
+        rmwMask = b.and_(anyEnabled, b.icmp_eq(isOddI32, b.i32_val(0)));
+
         // Unpack results back
-        Value unpackedDppRes = b.bitcast(dppMoveRes, packF16Ty);
-        operand = b.undef(packF16Ty);
-        operand =
-            b.insert_element(packF16Ty, operand, valElements[i], b.i32_val(0));
-        operand = b.insert_element(
-            packF16Ty, operand,
-            b.extract_element(valueElemTy, unpackedDppRes, b.i32_val(0)),
-            b.i32_val(1));
+        rightNeighbourPtr = b.inttoptr(rmwPtr.getType(), rightNeighbourAddr);
+        rmwPtr = b.inttoptr(rmwPtr.getType(), castedAddr);
+        operand = b.bitcast(b.or_(packedVal, dppMoveRes), packF16Ty);
       } else if (vec == 1) {
         operand = valElements[i];
       } else {
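The three conditions folded into enablePackedOpt above can also be checked on the host; the following standalone sketch mirrors that predicate. The all-ones address used as the "disabled lane" sentinel follows the i64Ones convention in the diff, while the function name and sample addresses are illustrative assumptions.

#include <cstdint>
#include <cstdio>

// All-ones sentinel marking a lane whose mask is off, as in the lowering.
constexpr uint64_t kDisabledAddr = ~uint64_t(0);

// A pair of lanes may share one packed atomic only if:
//   1. the even lane's address is 4-byte aligned,
//   2. the odd lane's address is exactly one element further, and
//   3. both lanes are active.
bool canUsePackedAtomic(uint64_t evenAddr, uint64_t oddAddr,
                        unsigned elemBytes /* 2 for f16/bf16 */) {
  bool bothEnabled = evenAddr != kDisabledAddr && oddAddr != kDisabledAddr;
  bool isAligned = evenAddr % 4 == 0;
  bool isAdjacent = oddAddr == evenAddr + elemBytes;
  return bothEnabled && isAligned && isAdjacent;
}

int main() {
  std::printf("%d\n", canUsePackedAtomic(0x1000, 0x1002, 2));        // 1: packed path
  std::printf("%d\n", canUsePackedAtomic(0x1002, 0x1004, 2));        // 0: misaligned
  std::printf("%d\n", canUsePackedAtomic(0x1000, kDisabledAddr, 2)); // 0: odd lane off
  return 0;
}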
@@ -1388,15 +1421,47 @@ struct AtomicRMWOpConversion
       rewriter.setInsertionPointToEnd(atomicBlock);
       auto maybeKind = matchAtomicOp(atomicRmwAttr);
       Value atom;
+      Value isVecOp;
       if (enableIntraWaveReduce) {
         atom = atomicIntraWaveReduce(rewriter, rmwPtr, operand, *maybeKind,
-                                     atomicMemOrdering, scopeStr.value());
+                                     atomicMemOrdering, *scopeStr);
       } else {
-        atom = rewriter
-                   .create<LLVM::AtomicRMWOp>(loc, *maybeKind, rmwPtr, operand,
-                                              atomicMemOrdering,
-                                              StringRef(scopeStr.value()))
-                   .getResult();
+        if (useDppForPackedF16) {
+          // Determine on the runtime what atomic intrinsic to execute:
+          // packed or regular.
+          auto *packedBlock =
+              atomicBlock->splitBlock(rewriter.getInsertionPoint());
+          auto *regularBlock =
+              rewriter.createBlock(atomicBlock->getParent(),
+                                   std::next(Region::iterator(atomicBlock)));
+          rewriter.setInsertionPointToEnd(atomicBlock);
+          rewriter.create<LLVM::CondBrOp>(loc, enablePackedOpt, packedBlock,
+                                          regularBlock);
+
+          // Fill out the regular block, where we issue two atomic ops.
+          rewriter.setInsertionPointToEnd(regularBlock);
+          Value pairedOperand0 =
+              b.extract_element(valueElemTy, operand, b.i32_val(0));
+          Value pairedOperand1 =
+              b.extract_element(valueElemTy, operand, b.i32_val(1));
+          Value atomNonVec0 = rewriter.create<LLVM::AtomicRMWOp>(
+              loc, *maybeKind, rmwPtr, pairedOperand0, atomicMemOrdering,
+              *scopeStr);
+          Value atomNonVec1 = rewriter.create<LLVM::AtomicRMWOp>(
+              loc, *maybeKind, rightNeighbourPtr, pairedOperand1,
+              atomicMemOrdering, *scopeStr);
+          Value packedRes = b.undef(packF16Ty);
+          packedRes =
+              b.insert_element(packF16Ty, packedRes, atomNonVec0, b.i32_val(0));
+          packedRes =
+              b.insert_element(packF16Ty, packedRes, atomNonVec1, b.i32_val(1));
+          rewriter.create<LLVM::BrOp>(loc, packedRes, endBlock);
+
+          // Start to fill out the packed block.
+          rewriter.setInsertionPointToEnd(packedBlock);
+        }
+        atom = rewriter.create<LLVM::AtomicRMWOp>(
+            loc, *maybeKind, rmwPtr, operand, atomicMemOrdering, *scopeStr);
       }
 
       if (!tensorTy) {
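The block structure built above amounts to a runtime if/else that either issues one packed atomic or two scalar ones and re-packs their results before the common merge point. A plain C++ model of that shape is sketched below: uint16_t stands in for the f16 payload, "+=" for the atomic RMW, and the returned values are the pre-update contents, matching atomicrmw semantics. None of these names come from the lowering itself.

#include <cstdint>
#include <cstdio>

struct Packed {
  uint16_t lo, hi;
};

// Model of the packed-vs-regular branch: like atomicrmw, each path yields the
// value that was in memory before the update, and both paths feed one merge
// point (the analogue of the branch into endBlock).
Packed atomicPairModel(bool enablePackedOpt, uint16_t *ptr,
                       uint16_t *rightNeighbourPtr, Packed operand) {
  Packed result;
  if (enablePackedOpt) {
    // Packed block: one (conceptually single, 32-bit) update of two halves.
    result = {ptr[0], ptr[1]};
    ptr[0] += operand.lo;
    ptr[1] += operand.hi;
  } else {
    // Regular block: two independent scalar updates, results re-packed.
    result = {*ptr, *rightNeighbourPtr};
    *ptr += operand.lo;
    *rightNeighbourPtr += operand.hi;
  }
  return result;
}

int main() {
  uint16_t buf[2] = {10, 20};
  Packed r = atomicPairModel(true, buf, buf + 1, {1, 2});
  std::printf("old: %u %u, new: %u %u\n", r.lo, r.hi, buf[0], buf[1]);
  // prints: old: 10 20, new: 11 22
  return 0;
}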
@@ -1623,11 +1688,8 @@ struct AtomicRMWOpConversion
     rewriter.setInsertionPointToEnd(leaderBlock);
     // Utilize global atomic only by leader threads
     rmwPtr = b.inttoptr(origPtrType, rmwPtr);
-    Value atom = rewriter
-                     .create<LLVM::AtomicRMWOp>(loc, opKind, rmwPtr,
-                                                afterRedBlock->getArgument(0),
-                                                memOrdering, scope)
-                     .getResult();
+    Value atom = rewriter.create<LLVM::AtomicRMWOp>(
+        loc, opKind, rmwPtr, afterRedBlock->getArgument(0), memOrdering, scope);
     rewriter.create<LLVM::BrOp>(loc, atom, endBlock);
     rewriter.setInsertionPointToStart(endBlock);
 