174174#include " llvm/IR/Function.h"
175175#include " llvm/IR/GetElementPtrTypeIterator.h"
176176#include " llvm/IR/IRBuilder.h"
177+ #include " llvm/IR/InstIterator.h"
177178#include " llvm/IR/InstrTypes.h"
178179#include " llvm/IR/Instruction.h"
179180#include " llvm/IR/Instructions.h"
190191#include " llvm/Support/ErrorHandling.h"
191192#include " llvm/Support/raw_ostream.h"
192193#include " llvm/Transforms/Scalar.h"
194+ #include " llvm/Transforms/Utils/BasicBlockUtils.h"
193195#include " llvm/Transforms/Utils/Local.h"
194196#include < cassert>
195197#include < cstdint>
198200using namespace llvm ;
199201using namespace llvm ::PatternMatch;
200202
203+ #define DEBUG_TYPE " separate-offset-gep"
204+
201205static cl::opt<bool > DisableSeparateConstOffsetFromGEP (
202206 " disable-separate-const-offset-from-gep" , cl::init(false ),
203207 cl::desc(" Do not separate the constant offset from a GEP instruction" ),
@@ -488,6 +492,42 @@ class SeparateConstOffsetFromGEP {
488492 DenseMap<ExprKey, SmallVector<Instruction *, 2 >> DominatingSubs;
489493};
490494
495+ // / A helper class that aims to convert xor operations into or operations when
496+ // / their operands are disjoint and the result is used in a GEP's index. This
497+ // / can then enable further GEP optimizations by effectively turning BaseVal |
498+ // / Const into BaseVal + Const when they are disjoint, which
499+ // / SeparateConstOffsetFromGEP can then process. This is a common pattern that
500+ // / sets up a grid of memory accesses across a wave where each thread acesses
501+ // / data at various offsets.
502+ class XorToOrDisjointTransformer {
503+ public:
504+ XorToOrDisjointTransformer (Function &F, DominatorTree &DT,
505+ const DataLayout &DL)
506+ : F(F), DT(DT), DL(DL) {}
507+
508+ bool run ();
509+
510+ private:
511+ Function &F;
512+ DominatorTree &DT;
513+ const DataLayout &DL;
514+ // / Maps a common operand to all Xor instructions
515+ using XorOpList = SmallVector<std::pair<BinaryOperator *, APInt>, 8 >;
516+ using XorBaseValInst = DenseMap<Instruction *, XorOpList>;
517+ XorBaseValInst XorGroups;
518+
519+ // / Checks if the given value has at least one GetElementPtr user
520+ static bool hasGEPUser (const Value *V);
521+
522+ // / Helper function to check if BaseXor dominates all XORs in the group
523+ bool dominatesAllXors (BinaryOperator *BaseXor, const XorOpList &XorsInGroup);
524+
525+ // / Processes a group of XOR instructions that share the same non-constant
526+ // / base operand. Returns true if this group's processing modified the
527+ // / function.
528+ bool processXorGroup (Instruction *OriginalBaseInst, XorOpList &XorsInGroup);
529+ };
530+
491531} // end anonymous namespace
492532
493533char SeparateConstOffsetFromGEPLegacyPass::ID = 0 ;
@@ -1223,6 +1263,154 @@ bool SeparateConstOffsetFromGEP::splitGEP(GetElementPtrInst *GEP) {
12231263 return true ;
12241264}
12251265
1266+ // Helper function to check if an instruction has at least one GEP user
1267+ bool XorToOrDisjointTransformer::hasGEPUser (const Value *V) {
1268+ return llvm::any_of (V->users (), [](const User *U) {
1269+ return isa<llvm::GetElementPtrInst>(U);
1270+ });
1271+ }
1272+
1273+ bool XorToOrDisjointTransformer::dominatesAllXors (
1274+ BinaryOperator *BaseXor, const XorOpList &XorsInGroup) {
1275+ return llvm::all_of (XorsInGroup, [&](const auto &XorEntry) {
1276+ BinaryOperator *XorInst = XorEntry.first ;
1277+ // Do not evaluate the BaseXor, otherwise we end up cloning it.
1278+ return XorInst == BaseXor || DT.dominates (BaseXor, XorInst);
1279+ });
1280+ }
1281+
1282+ bool XorToOrDisjointTransformer::processXorGroup (Instruction *OriginalBaseInst,
1283+ XorOpList &XorsInGroup) {
1284+ bool Changed = false ;
1285+ if (XorsInGroup.size () <= 1 )
1286+ return false ;
1287+
1288+ // Sort XorsInGroup by the constant offset value in increasing order.
1289+ llvm::sort (XorsInGroup, [](const auto &A, const auto &B) {
1290+ return A.second .slt (B.second );
1291+ });
1292+
1293+ // Dominance check
1294+ // The "base" XOR for dominance purposes is the one with the smallest
1295+ // constant.
1296+ BinaryOperator *XorWithSmallConst = XorsInGroup[0 ].first ;
1297+
1298+ if (!dominatesAllXors (XorWithSmallConst, XorsInGroup)) {
1299+ LLVM_DEBUG (dbgs () << DEBUG_TYPE
1300+ << " : Cloning and inserting XOR with smallest constant ("
1301+ << *XorWithSmallConst
1302+ << " ) as it does not dominate all other XORs"
1303+ << " in function " << F.getName () << " \n " );
1304+
1305+ BinaryOperator *ClonedXor =
1306+ cast<BinaryOperator>(XorWithSmallConst->clone ());
1307+ ClonedXor->setName (XorWithSmallConst->getName () + " .dom_clone" );
1308+ ClonedXor->insertAfter (OriginalBaseInst);
1309+ LLVM_DEBUG (dbgs () << " Cloned Inst: " << *ClonedXor << " \n " );
1310+ Changed = true ;
1311+ XorWithSmallConst = ClonedXor;
1312+ }
1313+
1314+ SmallVector<Instruction *, 8 > InstructionsToErase;
1315+ const APInt SmallestConst =
1316+ cast<ConstantInt>(XorWithSmallConst->getOperand (1 ))->getValue ();
1317+
1318+ // Main transformation loop: Iterate over the original XORs in the sorted
1319+ // group.
1320+ for (const auto &XorEntry : XorsInGroup) {
1321+ BinaryOperator *XorInst = XorEntry.first ; // Original XOR instruction
1322+ const APInt ConstOffsetVal = XorEntry.second ;
1323+
1324+ // Do not process the one with smallest constant as it is the base.
1325+ if (XorInst == XorWithSmallConst)
1326+ continue ;
1327+
1328+ // Disjointness Check 1
1329+ APInt NewConstVal = ConstOffsetVal - SmallestConst;
1330+ if ((NewConstVal & SmallestConst) != 0 ) {
1331+ LLVM_DEBUG (dbgs () << DEBUG_TYPE << " : Cannot transform XOR in function "
1332+ << F.getName () << " :\n "
1333+ << " New Const: " << NewConstVal
1334+ << " Smallest Const: " << SmallestConst
1335+ << " are not disjoint \n " );
1336+ continue ;
1337+ }
1338+
1339+ // Disjointness Check 2
1340+ if (MaskedValueIsZero (XorWithSmallConst, NewConstVal, SimplifyQuery (DL),
1341+ 0 )) {
1342+ LLVM_DEBUG (dbgs () << DEBUG_TYPE
1343+ << " : Transforming XOR to OR (disjoint) in function "
1344+ << F.getName () << " :\n "
1345+ << " Xor: " << *XorInst << " \n "
1346+ << " Base Val: " << *XorWithSmallConst << " \n "
1347+ << " New Const: " << NewConstVal << " \n " );
1348+
1349+ auto *NewOrInst = BinaryOperator::CreateDisjointOr (
1350+ XorWithSmallConst,
1351+ ConstantInt::get (OriginalBaseInst->getType (), NewConstVal),
1352+ XorInst->getName () + " .or_disjoint" , XorInst->getIterator ());
1353+
1354+ NewOrInst->copyMetadata (*XorInst);
1355+ XorInst->replaceAllUsesWith (NewOrInst);
1356+ LLVM_DEBUG (dbgs () << " New Inst: " << *NewOrInst << " \n " );
1357+ InstructionsToErase.push_back (XorInst); // Mark original XOR for deletion
1358+
1359+ Changed = true ;
1360+ } else {
1361+ LLVM_DEBUG (
1362+ dbgs () << DEBUG_TYPE
1363+ << " : Cannot transform XOR (not proven disjoint) in function "
1364+ << F.getName () << " :\n "
1365+ << " Xor: " << *XorInst << " \n "
1366+ << " Base Val: " << *XorWithSmallConst << " \n "
1367+ << " New Const: " << NewConstVal << " \n " );
1368+ }
1369+ }
1370+
1371+ for (Instruction *I : InstructionsToErase)
1372+ I->eraseFromParent ();
1373+
1374+ return Changed;
1375+ }
1376+
1377+ // Try to transform XOR(A, B+C) in to XOR(A,C) + B where XOR(A,C) becomes
1378+ // the base for memory operations. This transformation is true under the
1379+ // following conditions
1380+ // Check 1 - B and C are disjoint.
1381+ // Check 2 - XOR(A,C) and B are disjoint.
1382+ //
1383+ // This transformation is beneficial particularly for GEPs because:
1384+ // 1. OR operations often map better to addressing modes than XOR
1385+ // 2. Disjoint OR operations preserve the semantics of the original XOR
1386+ // 3. This can enable further optimizations in the GEP offset folding pipeline
1387+ bool XorToOrDisjointTransformer::run () {
1388+ bool Changed = false ;
1389+
1390+ // Collect all candidate XORs
1391+ for (Instruction &I : instructions (F)) {
1392+ Instruction *Op0 = nullptr ;
1393+ ConstantInt *C1 = nullptr ;
1394+ BinaryOperator *MatchedXorOp = nullptr ;
1395+
1396+ // Attempt to match the instruction 'I' as XOR operation.
1397+ if (match (&I, m_CombineAnd (m_Xor (m_Instruction (Op0), m_ConstantInt (C1)),
1398+ m_BinOp (MatchedXorOp))) &&
1399+ hasGEPUser (MatchedXorOp))
1400+ XorGroups[Op0].emplace_back (MatchedXorOp, C1->getValue ());
1401+ }
1402+
1403+ if (XorGroups.empty ())
1404+ return false ;
1405+
1406+ // Process each group of XORs
1407+ for (auto &[OriginalBaseInst, XorsInGroup] : XorGroups)
1408+ if (processXorGroup (OriginalBaseInst, XorsInGroup))
1409+ Changed = true ;
1410+
1411+ return Changed;
1412+ }
1413+
12261414bool SeparateConstOffsetFromGEPLegacyPass::runOnFunction (Function &F) {
12271415 if (skipFunction (F))
12281416 return false ;
@@ -1242,6 +1430,11 @@ bool SeparateConstOffsetFromGEP::run(Function &F) {
12421430
12431431 DL = &F.getDataLayout ();
12441432 bool Changed = false ;
1433+
1434+ // Decompose xor in to "or disjoint" if possible.
1435+ XorToOrDisjointTransformer XorTransformer (F, *DT, *DL);
1436+ Changed |= XorTransformer.run ();
1437+
12451438 for (BasicBlock &B : F) {
12461439 if (!DT->isReachableFromEntry (&B))
12471440 continue ;
0 commit comments