174174#include " llvm/IR/Function.h"
175175#include " llvm/IR/GetElementPtrTypeIterator.h"
176176#include " llvm/IR/IRBuilder.h"
177+ #include " llvm/IR/InstIterator.h"
177178#include " llvm/IR/InstrTypes.h"
178179#include " llvm/IR/Instruction.h"
179180#include " llvm/IR/Instructions.h"
190191#include " llvm/Support/ErrorHandling.h"
191192#include " llvm/Support/raw_ostream.h"
192193#include " llvm/Transforms/Scalar.h"
194+ #include " llvm/Transforms/Utils/BasicBlockUtils.h"
193195#include " llvm/Transforms/Utils/Local.h"
194196#include < cassert>
195197#include < cstdint>
@@ -491,6 +493,39 @@ class SeparateConstOffsetFromGEP {
491493 Value *tryFoldXorToOrDisjoint (Instruction &I);
492494};
493495
496+ // / A helper class that aims to convert xor operations into or operations when
497+ // / their operands are disjoint and the result is used in a GEP's index. This
498+ // / can then enable further GEP optimizations by effectively turning BaseVal |
499+ // / Const into BaseVal + Const when they are disjoint, which
500+ // / SeparateConstOffsetFromGEP can then process. This is a common pattern that
501+ // / sets up a grid of memory accesses across a wave where each thread acesses
502+ // / data at various offsets.
503+ class XorToOrDisjointTransformer {
504+ public:
505+ XorToOrDisjointTransformer (Function &F, DominatorTree &DT,
506+ const DataLayout &DL)
507+ : F(F), DT(DT), DL(DL) {}
508+
509+ bool run ();
510+
511+ private:
512+ Function &F;
513+ DominatorTree &DT;
514+ const DataLayout &DL;
515+ // / Maps a common operand to all Xor instructions
516+ using XorOpList = SmallVector<std::pair<BinaryOperator *, APInt>, 8 >;
517+ using XorBaseValMap = DenseMap<Value *, XorOpList>;
518+ XorBaseValMap XorGroups;
519+
520+ // / Checks if the given value has at least one GetElementPtr user
521+ bool hasGEPUser (const Value *V) const ;
522+
523+ // / Processes a group of XOR instructions that share the same non-constant
524+ // / base operand. Returns true if this group's processing modified the
525+ // / function.
526+ bool processXorGroup (Value *OriginalBaseVal, XorOpList &XorsInGroup);
527+ };
528+
494529} // end anonymous namespace
495530
496531char SeparateConstOffsetFromGEPLegacyPass::ID = 0 ;
@@ -1167,177 +1202,163 @@ bool SeparateConstOffsetFromGEP::splitGEP(GetElementPtrInst *GEP) {
11671202 return true ;
11681203}
11691204
1170- bool SeparateConstOffsetFromGEP::decomposeXor (Function &F) {
1171- bool FunctionChanged = false ;
1172- SmallVector<std::pair<Instruction *, Value *>, 16 > ReplacementsToMake;
1173-
1174- for (BasicBlock &BB : F) {
1175- for (Instruction &I : BB) {
1176- if (I.getOpcode () == Instruction::Xor) {
1177- if (Value *Replacement = tryFoldXorToOrDisjoint (I)) {
1178- ReplacementsToMake.push_back ({&I, Replacement});
1179- FunctionChanged = true ;
1180- }
1181- }
1205+ // Helper function to check if an instruction has at least one GEP user
1206+ bool XorToOrDisjointTransformer::hasGEPUser (const Value *V) const {
1207+ for (const User *U : V->users ()) {
1208+ if (isa<GetElementPtrInst>(U)) {
1209+ return true ;
11821210 }
11831211 }
1184-
1185- if (!ReplacementsToMake.empty ()) {
1186- LLVM_DEBUG (dbgs () << " Applying " << ReplacementsToMake.size ()
1187- << " XOR->OR Disjoint replacements in " << F.getName ()
1188- << " \n " );
1189- for (auto &Pair : ReplacementsToMake)
1190- Pair.first ->replaceAllUsesWith (Pair.second );
1191-
1192- for (auto &Pair : ReplacementsToMake)
1193- Pair.first ->eraseFromParent ();
1194- }
1195-
1196- return FunctionChanged;
1212+ return false ;
11971213}
11981214
1199- static llvm::Instruction *findClosestSequentialXor (Value *A, Instruction &I) {
1200- llvm::Instruction *ClosestUser = nullptr ;
1201- for (llvm::User *User : A->users ()) {
1202- if (auto *UserInst = llvm::dyn_cast<llvm::Instruction>(User)) {
1203- if (UserInst->getOpcode () != Instruction::Xor || UserInst == &I)
1204- continue ;
1205- if (!ClosestUser)
1206- ClosestUser = UserInst;
1207- else {
1208- // Compare instruction positions.
1209- if (UserInst->comesBefore (ClosestUser)) {
1210- ClosestUser = UserInst;
1211- }
1212- }
1213- }
1214- }
1215- return ClosestUser;
1216- }
1215+ bool XorToOrDisjointTransformer::processXorGroup (Value *OriginalBaseVal,
1216+ XorOpList &XorsInGroup) {
1217+ bool Changed = false ;
1218+ if (XorsInGroup.size () <= 1 )
1219+ return false ;
12171220
1218- // / Try to transform I = xor(A, C1) into or_disjoint(Y, C2)
1219- // / where Y = xor(A, C0) is another existing instruction dominating I,
1220- // / C2 = C1 - C0, and A is known to be disjoint with C2.
1221- // /
1222- // / This transformation is beneficial particularly for GEPs because:
1223- // / 1. OR operations often map better to addressing modes than XOR
1224- // / 2. Disjoint OR operations preserve the semantics of the original XOR
1225- // / 3. This can enable further optimizations in the GEP offset folding pipeline
1226- // /
1227- // / @param I The XOR instruction being visited.
1228- // / @return The replacement Value* if successful, nullptr otherwise.
1229- Value *SeparateConstOffsetFromGEP::tryFoldXorToOrDisjoint (Instruction &I) {
1230- assert (I.getOpcode () == Instruction::Xor && " Instruction must be XOR" );
1231-
1232- // Check if I has at least one GEP user.
1233- bool HasGepUser = false ;
1234- for (User *U : I.users ()) {
1235- if (isa<GetElementPtrInst>(U)) {
1236- HasGepUser = true ;
1221+ // Sort XorsInGroup by the constant offset value in increasing order.
1222+ llvm::sort (
1223+ XorsInGroup.begin (), XorsInGroup.end (),
1224+ [](const auto &A, const auto &B) { return A.second .ult (B.second ); });
1225+
1226+ // Dominance check
1227+ // The "base" XOR for dominance purposes is the one with the smallest
1228+ // constant.
1229+ BinaryOperator *XorWithSmallConst = XorsInGroup[0 ].first ;
1230+
1231+ for (size_t i = 1 ; i < XorsInGroup.size (); ++i) {
1232+ BinaryOperator *currentXorToProcess = XorsInGroup[i].first ;
1233+
1234+ // Check if the XorWithSmallConst dominates currentXorToProcess.
1235+ // If not, clone and add the instruction.
1236+ if (!DT.dominates (XorWithSmallConst, currentXorToProcess)) {
1237+ LLVM_DEBUG (
1238+ dbgs () << DEBUG_TYPE
1239+ << " : Cloning and inserting XOR with smallest constant ("
1240+ << *XorWithSmallConst << " ) as it does not dominate "
1241+ << *currentXorToProcess << " in function " << F.getName ()
1242+ << " \n " );
1243+
1244+ BinaryOperator *ClonedXor =
1245+ cast<BinaryOperator>(XorWithSmallConst->clone ());
1246+ ClonedXor->setName (XorWithSmallConst->getName () + " .dom_clone" );
1247+ ClonedXor->insertAfter (dyn_cast<Instruction>(OriginalBaseVal));
1248+ LLVM_DEBUG (dbgs () << " Cloned Inst: " << *ClonedXor << " \n " );
1249+ Changed = true ;
1250+ XorWithSmallConst = ClonedXor;
12371251 break ;
12381252 }
12391253 }
1240- // If no user is a GEP instruction, abort the transformation.
1241- if (!HasGepUser) {
1242- LLVM_DEBUG (
1243- dbgs () << " SeparateConstOffsetFromGEP: Skipping XOR->OR DISJOINT for"
1244- << I << " because it has no GEP users.\n " );
1245- return nullptr ;
1246- }
12471254
1248- Value *Op0 = I.getOperand (0 );
1249- Value *Op1 = I.getOperand (1 );
1250- ConstantInt *C1 = dyn_cast<ConstantInt>(Op1);
1251- Value *A = Op0;
1252-
1253- // Bail out of there is not constant operand.
1254- if (!C1) {
1255- C1 = dyn_cast<ConstantInt>(Op0);
1256- if (!C1)
1257- return nullptr ;
1258- A = Op1;
1259- }
1255+ SmallVector<Instruction *, 8 > InstructionsToErase;
1256+ const APInt SmallestConst =
1257+ dyn_cast<ConstantInt>(XorWithSmallConst->getOperand (1 ))->getValue ();
12601258
1261- if (isa<UndefValue>(A))
1262- return nullptr ;
1259+ // Main transformation loop: Iterate over the original XORs in the sorted
1260+ // group.
1261+ for (const auto &XorEntry : XorsInGroup) {
1262+ BinaryOperator *XorInst = XorEntry.first ; // Original XOR instruction
1263+ const APInt ConstOffsetVal = XorEntry.second ;
12631264
1264- APInt C1_APInt = C1-> getValue ();
1265- unsigned BitWidth = C1_APInt. getBitWidth ();
1266- Type *Ty = I. getType () ;
1265+ // Do not process the one with smallest constant as it is the base.
1266+ if (XorInst == XorWithSmallConst)
1267+ continue ;
12671268
1268- // Find Dominating Y = xor A, C0
1269- Instruction *FoundUserInst = nullptr ;
1270- APInt C0_APInt;
1269+ // Disjointness Check 1
1270+ APInt NewConstVal = ConstOffsetVal - SmallestConst;
1271+ if ((NewConstVal & SmallestConst) != 0 ) {
1272+ LLVM_DEBUG (dbgs () << DEBUG_TYPE << " : Cannot transform XOR in function "
1273+ << F.getName () << " :\n "
1274+ << " New Const: " << NewConstVal << " \n "
1275+ << " Smallest Const: " << SmallestConst << " \n "
1276+ << " are not disjoint \n " );
1277+ continue ;
1278+ }
12711279
1272- // Find the closest XOR instruction using the same value.
1273- Instruction *UserInst = findClosestSequentialXor (A, I);
1274- if (!UserInst) {
1275- LLVM_DEBUG (
1276- dbgs () << " SeparateConstOffsetFromGEP: No dominating XOR found for" << I
1277- << " \n " );
1278- return nullptr ;
1280+ // Disjointness Check 2
1281+ KnownBits KnownBaseBits (
1282+ XorWithSmallConst->getType ()->getScalarSizeInBits ());
1283+ computeKnownBits (XorWithSmallConst, KnownBaseBits, DL, 0 , nullptr ,
1284+ XorWithSmallConst, &DT);
1285+ if ((KnownBaseBits.Zero & NewConstVal) == NewConstVal) {
1286+ LLVM_DEBUG (dbgs () << DEBUG_TYPE
1287+ << " : Transforming XOR to OR (disjoint) in function "
1288+ << F.getName () << " :\n "
1289+ << " Xor: " << *XorInst << " \n "
1290+ << " Base Val: " << *XorWithSmallConst << " \n "
1291+ << " New Const: " << NewConstVal << " \n " );
1292+
1293+ auto *NewOrInst = BinaryOperator::CreateDisjointOr (
1294+ XorWithSmallConst,
1295+ ConstantInt::get (OriginalBaseVal->getType (), NewConstVal),
1296+ XorInst->getName () + " .or_disjoint" , XorInst->getIterator ());
1297+
1298+ NewOrInst->copyMetadata (*XorInst);
1299+ XorInst->replaceAllUsesWith (NewOrInst);
1300+ LLVM_DEBUG (dbgs () << " New Inst: " << *NewOrInst << " \n " );
1301+ InstructionsToErase.push_back (XorInst); // Mark original XOR for deletion
1302+
1303+ Changed = true ;
1304+ } else {
1305+ LLVM_DEBUG (
1306+ dbgs () << DEBUG_TYPE
1307+ << " : Cannot transform XOR (not proven disjoint) in function "
1308+ << F.getName () << " :\n "
1309+ << " Xor: " << *XorInst << " \n "
1310+ << " Base Val: " << *XorWithSmallConst << " \n "
1311+ << " New Const: " << NewConstVal << " \n " );
1312+ }
12791313 }
1314+ if (!InstructionsToErase.empty ())
1315+ for (Instruction *I : InstructionsToErase)
1316+ I->eraseFromParent ();
12801317
1281- BinaryOperator *UserBO = cast<BinaryOperator>(UserInst);
1282- Value *UserOp0 = UserBO->getOperand (0 );
1283- Value *UserOp1 = UserBO->getOperand (1 );
1284- ConstantInt *UserC = nullptr ;
1285- if (UserOp0 == A)
1286- UserC = dyn_cast<ConstantInt>(UserOp1);
1287- else if (UserOp1 == A)
1288- UserC = dyn_cast<ConstantInt>(UserOp0);
1289- else {
1290- LLVM_DEBUG (dbgs () << " SeparateConstOffsetFromGEP: Found XOR" << *UserInst
1291- << " doesn't use value " << *A << " \n " );
1292- return nullptr ;
1293- }
1318+ return Changed;
1319+ }
12941320
1295- if (!UserC) {
1296- LLVM_DEBUG (
1297- dbgs ()
1298- << " SeparateConstOffsetFromGEP: Found XOR doesn't have constant operand"
1299- << *UserInst << " \n " );
1300- return nullptr ;
1301- }
1321+ // Try to transform XOR(A, B+C) in to XOR(A,C) + B where XOR(A,C) becomes
1322+ // the base for memory operations. This transformation is true under the
1323+ // following conditions
1324+ // Check 1 - B and C are disjoint.
1325+ // Check 2 - XOR(A,C) and B are disjoint.
1326+ //
1327+ // This transformation is beneficial particularly for GEPs because:
1328+ // 1. OR operations often map better to addressing modes than XOR
1329+ // 2. Disjoint OR operations preserve the semantics of the original XOR
1330+ // 3. This can enable further optimizations in the GEP offset folding pipeline
1331+ bool XorToOrDisjointTransformer::run () {
1332+ bool Changed = false ;
13021333
1303- if (!DT->dominates (UserInst, &I)) {
1304- LLVM_DEBUG (dbgs () << " SeparateConstOffsetFromGEP: Found XOR" << *UserInst
1305- << " doesn't dominate " << I << " \n " );
1306- return nullptr ;
1334+ // Collect all candidate XORs
1335+ for (Instruction &I : instructions (F)) {
1336+ if (auto *XorOp = dyn_cast<BinaryOperator>(&I)) {
1337+ if (XorOp->getOpcode () == Instruction::Xor) {
1338+ Value *Op0 = XorOp->getOperand (0 );
1339+ ConstantInt *C1 = nullptr ;
1340+ // Match: xor Op0, Constant
1341+ if (match (XorOp->getOperand (1 ), m_ConstantInt (C1))) {
1342+ if (hasGEPUser (XorOp)) {
1343+ XorGroups[Op0].push_back ({XorOp, C1->getValue ()});
1344+ }
1345+ }
1346+ }
1347+ }
13071348 }
13081349
1309- FoundUserInst = UserInst;
1310- C0_APInt = UserC->getValue ();
1311-
1312- // Calculate C2 = C1 - C0.
1313- APInt C2_APInt = C1_APInt - C0_APInt;
1314-
1315- // Check Disjointness A & C2 == 0.
1316- KnownBits KnownA (BitWidth);
1317- computeKnownBits (A, KnownA, *DL, 0 , nullptr , &I, DT);
1350+ if (XorGroups.empty ())
1351+ return false ;
13181352
1319- if ((KnownA.One & C2_APInt) != 0 ) {
1320- LLVM_DEBUG (
1321- dbgs () << " SeparateConstOffsetFromGEP: Disjointness check failed for"
1322- << I << " \n " );
1323- return nullptr ;
1353+ // Process each group of XORs
1354+ for (auto &GroupPair : XorGroups) {
1355+ Value *OriginalBaseVal = GroupPair.first ;
1356+ XorOpList &XorsInGroup = GroupPair.second ;
1357+ if (processXorGroup (OriginalBaseVal, XorsInGroup))
1358+ Changed = true ;
13241359 }
13251360
1326- IRBuilder<> Builder (&I);
1327- Builder.SetInsertPoint (&I);
1328- Constant *C2_Const = ConstantInt::get (Ty, C2_APInt);
1329- Twine Name = I.getName ();
1330- Value *NewOr = BinaryOperator::CreateDisjointOr (FoundUserInst, C2_Const, Name,
1331- I.getIterator ());
1332- // Preserve metadata
1333- if (Instruction *NewOrInst = dyn_cast<Instruction>(NewOr))
1334- NewOrInst->copyMetadata (I);
1335-
1336- LLVM_DEBUG (dbgs () << " SeparateConstOffsetFromGEP: Replacing" << I
1337- << " (used by GEP) with" << *NewOr << " based on"
1338- << *FoundUserInst << " \n " );
1339-
1340- return NewOr;
1361+ return Changed;
13411362}
13421363
13431364bool SeparateConstOffsetFromGEPLegacyPass::runOnFunction (Function &F) {
@@ -1361,7 +1382,8 @@ bool SeparateConstOffsetFromGEP::run(Function &F) {
13611382 bool Changed = false ;
13621383
13631384 // Decompose xor in to "or disjoint" if possible.
1364- Changed |= decomposeXor (F);
1385+ XorToOrDisjointTransformer XorTransformer (F, *DT, *DL);
1386+ Changed |= XorTransformer.run ();
13651387
13661388 for (BasicBlock &B : F) {
13671389 if (!DT->isReachableFromEntry (&B))
0 commit comments