Skip to content

Commit 7ec5895

Browse files
committed
[SeparateConstOffsetFromGEP] Decompose constant xor operand if possible
Try to transform XOR(A, B+C) in to XOR(A,C) + B where XOR(A,C) becomes the base for memory operations. This transformation is true under the following conditions Check 1 - B and C are disjoint. Check 2 - XOR(A,C) and B are disjoint. This transformation is beneficial particularly for GEPs because Disjoint OR operations often map better to addressing modes than XOR. This can enable further optimizations in the GEP offset folding pipeline
1 parent 3492929 commit 7ec5895

File tree

2 files changed

+297
-265
lines changed

2 files changed

+297
-265
lines changed

llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp

Lines changed: 172 additions & 150 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@
174174
#include "llvm/IR/Function.h"
175175
#include "llvm/IR/GetElementPtrTypeIterator.h"
176176
#include "llvm/IR/IRBuilder.h"
177+
#include "llvm/IR/InstIterator.h"
177178
#include "llvm/IR/InstrTypes.h"
178179
#include "llvm/IR/Instruction.h"
179180
#include "llvm/IR/Instructions.h"
@@ -190,6 +191,7 @@
190191
#include "llvm/Support/ErrorHandling.h"
191192
#include "llvm/Support/raw_ostream.h"
192193
#include "llvm/Transforms/Scalar.h"
194+
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
193195
#include "llvm/Transforms/Utils/Local.h"
194196
#include <cassert>
195197
#include <cstdint>
@@ -491,6 +493,39 @@ class SeparateConstOffsetFromGEP {
491493
Value *tryFoldXorToOrDisjoint(Instruction &I);
492494
};
493495

496+
/// A helper class that aims to convert xor operations into or operations when
497+
/// their operands are disjoint and the result is used in a GEP's index. This
498+
/// can then enable further GEP optimizations by effectively turning BaseVal |
499+
/// Const into BaseVal + Const when they are disjoint, which
500+
/// SeparateConstOffsetFromGEP can then process. This is a common pattern that
501+
/// sets up a grid of memory accesses across a wave where each thread acesses
502+
/// data at various offsets.
503+
class XorToOrDisjointTransformer {
504+
public:
505+
XorToOrDisjointTransformer(Function &F, DominatorTree &DT,
506+
const DataLayout &DL)
507+
: F(F), DT(DT), DL(DL) {}
508+
509+
bool run();
510+
511+
private:
512+
Function &F;
513+
DominatorTree &DT;
514+
const DataLayout &DL;
515+
/// Maps a common operand to all Xor instructions
516+
using XorOpList = SmallVector<std::pair<BinaryOperator *, APInt>, 8>;
517+
using XorBaseValMap = DenseMap<Value *, XorOpList>;
518+
XorBaseValMap XorGroups;
519+
520+
/// Checks if the given value has at least one GetElementPtr user
521+
bool hasGEPUser(const Value *V) const;
522+
523+
/// Processes a group of XOR instructions that share the same non-constant
524+
/// base operand. Returns true if this group's processing modified the
525+
/// function.
526+
bool processXorGroup(Value *OriginalBaseVal, XorOpList &XorsInGroup);
527+
};
528+
494529
} // end anonymous namespace
495530

496531
char SeparateConstOffsetFromGEPLegacyPass::ID = 0;
@@ -1167,177 +1202,163 @@ bool SeparateConstOffsetFromGEP::splitGEP(GetElementPtrInst *GEP) {
11671202
return true;
11681203
}
11691204

1170-
bool SeparateConstOffsetFromGEP::decomposeXor(Function &F) {
1171-
bool FunctionChanged = false;
1172-
SmallVector<std::pair<Instruction *, Value *>, 16> ReplacementsToMake;
1173-
1174-
for (BasicBlock &BB : F) {
1175-
for (Instruction &I : BB) {
1176-
if (I.getOpcode() == Instruction::Xor) {
1177-
if (Value *Replacement = tryFoldXorToOrDisjoint(I)) {
1178-
ReplacementsToMake.push_back({&I, Replacement});
1179-
FunctionChanged = true;
1180-
}
1181-
}
1205+
// Helper function to check if an instruction has at least one GEP user
1206+
bool XorToOrDisjointTransformer::hasGEPUser(const Value *V) const {
1207+
for (const User *U : V->users()) {
1208+
if (isa<GetElementPtrInst>(U)) {
1209+
return true;
11821210
}
11831211
}
1184-
1185-
if (!ReplacementsToMake.empty()) {
1186-
LLVM_DEBUG(dbgs() << "Applying " << ReplacementsToMake.size()
1187-
<< " XOR->OR Disjoint replacements in " << F.getName()
1188-
<< "\n");
1189-
for (auto &Pair : ReplacementsToMake)
1190-
Pair.first->replaceAllUsesWith(Pair.second);
1191-
1192-
for (auto &Pair : ReplacementsToMake)
1193-
Pair.first->eraseFromParent();
1194-
}
1195-
1196-
return FunctionChanged;
1212+
return false;
11971213
}
11981214

1199-
static llvm::Instruction *findClosestSequentialXor(Value *A, Instruction &I) {
1200-
llvm::Instruction *ClosestUser = nullptr;
1201-
for (llvm::User *User : A->users()) {
1202-
if (auto *UserInst = llvm::dyn_cast<llvm::Instruction>(User)) {
1203-
if (UserInst->getOpcode() != Instruction::Xor || UserInst == &I)
1204-
continue;
1205-
if (!ClosestUser)
1206-
ClosestUser = UserInst;
1207-
else {
1208-
// Compare instruction positions.
1209-
if (UserInst->comesBefore(ClosestUser)) {
1210-
ClosestUser = UserInst;
1211-
}
1212-
}
1213-
}
1214-
}
1215-
return ClosestUser;
1216-
}
1215+
bool XorToOrDisjointTransformer::processXorGroup(Value *OriginalBaseVal,
1216+
XorOpList &XorsInGroup) {
1217+
bool Changed = false;
1218+
if (XorsInGroup.size() <= 1)
1219+
return false;
12171220

1218-
/// Try to transform I = xor(A, C1) into or_disjoint(Y, C2)
1219-
/// where Y = xor(A, C0) is another existing instruction dominating I,
1220-
/// C2 = C1 - C0, and A is known to be disjoint with C2.
1221-
///
1222-
/// This transformation is beneficial particularly for GEPs because:
1223-
/// 1. OR operations often map better to addressing modes than XOR
1224-
/// 2. Disjoint OR operations preserve the semantics of the original XOR
1225-
/// 3. This can enable further optimizations in the GEP offset folding pipeline
1226-
///
1227-
/// @param I The XOR instruction being visited.
1228-
/// @return The replacement Value* if successful, nullptr otherwise.
1229-
Value *SeparateConstOffsetFromGEP::tryFoldXorToOrDisjoint(Instruction &I) {
1230-
assert(I.getOpcode() == Instruction::Xor && "Instruction must be XOR");
1231-
1232-
// Check if I has at least one GEP user.
1233-
bool HasGepUser = false;
1234-
for (User *U : I.users()) {
1235-
if (isa<GetElementPtrInst>(U)) {
1236-
HasGepUser = true;
1221+
// Sort XorsInGroup by the constant offset value in increasing order.
1222+
llvm::sort(
1223+
XorsInGroup.begin(), XorsInGroup.end(),
1224+
[](const auto &A, const auto &B) { return A.second.ult(B.second); });
1225+
1226+
// Dominance check
1227+
// The "base" XOR for dominance purposes is the one with the smallest
1228+
// constant.
1229+
BinaryOperator *XorWithSmallConst = XorsInGroup[0].first;
1230+
1231+
for (size_t i = 1; i < XorsInGroup.size(); ++i) {
1232+
BinaryOperator *currentXorToProcess = XorsInGroup[i].first;
1233+
1234+
// Check if the XorWithSmallConst dominates currentXorToProcess.
1235+
// If not, clone and add the instruction.
1236+
if (!DT.dominates(XorWithSmallConst, currentXorToProcess)) {
1237+
LLVM_DEBUG(
1238+
dbgs() << DEBUG_TYPE
1239+
<< ": Cloning and inserting XOR with smallest constant ("
1240+
<< *XorWithSmallConst << ") as it does not dominate "
1241+
<< *currentXorToProcess << " in function " << F.getName()
1242+
<< "\n");
1243+
1244+
BinaryOperator *ClonedXor =
1245+
cast<BinaryOperator>(XorWithSmallConst->clone());
1246+
ClonedXor->setName(XorWithSmallConst->getName() + ".dom_clone");
1247+
ClonedXor->insertAfter(dyn_cast<Instruction>(OriginalBaseVal));
1248+
LLVM_DEBUG(dbgs() << " Cloned Inst: " << *ClonedXor << "\n");
1249+
Changed = true;
1250+
XorWithSmallConst = ClonedXor;
12371251
break;
12381252
}
12391253
}
1240-
// If no user is a GEP instruction, abort the transformation.
1241-
if (!HasGepUser) {
1242-
LLVM_DEBUG(
1243-
dbgs() << "SeparateConstOffsetFromGEP: Skipping XOR->OR DISJOINT for"
1244-
<< I << " because it has no GEP users.\n");
1245-
return nullptr;
1246-
}
12471254

1248-
Value *Op0 = I.getOperand(0);
1249-
Value *Op1 = I.getOperand(1);
1250-
ConstantInt *C1 = dyn_cast<ConstantInt>(Op1);
1251-
Value *A = Op0;
1252-
1253-
// Bail out of there is not constant operand.
1254-
if (!C1) {
1255-
C1 = dyn_cast<ConstantInt>(Op0);
1256-
if (!C1)
1257-
return nullptr;
1258-
A = Op1;
1259-
}
1255+
SmallVector<Instruction *, 8> InstructionsToErase;
1256+
const APInt SmallestConst =
1257+
dyn_cast<ConstantInt>(XorWithSmallConst->getOperand(1))->getValue();
12601258

1261-
if (isa<UndefValue>(A))
1262-
return nullptr;
1259+
// Main transformation loop: Iterate over the original XORs in the sorted
1260+
// group.
1261+
for (const auto &XorEntry : XorsInGroup) {
1262+
BinaryOperator *XorInst = XorEntry.first; // Original XOR instruction
1263+
const APInt ConstOffsetVal = XorEntry.second;
12631264

1264-
APInt C1_APInt = C1->getValue();
1265-
unsigned BitWidth = C1_APInt.getBitWidth();
1266-
Type *Ty = I.getType();
1265+
// Do not process the one with smallest constant as it is the base.
1266+
if (XorInst == XorWithSmallConst)
1267+
continue;
12671268

1268-
// Find Dominating Y = xor A, C0
1269-
Instruction *FoundUserInst = nullptr;
1270-
APInt C0_APInt;
1269+
// Disjointness Check 1
1270+
APInt NewConstVal = ConstOffsetVal - SmallestConst;
1271+
if ((NewConstVal & SmallestConst) != 0) {
1272+
LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Cannot transform XOR in function "
1273+
<< F.getName() << ":\n"
1274+
<< " New Const: " << NewConstVal << "\n"
1275+
<< " Smallest Const: " << SmallestConst << "\n"
1276+
<< " are not disjoint \n");
1277+
continue;
1278+
}
12711279

1272-
// Find the closest XOR instruction using the same value.
1273-
Instruction *UserInst = findClosestSequentialXor(A, I);
1274-
if (!UserInst) {
1275-
LLVM_DEBUG(
1276-
dbgs() << "SeparateConstOffsetFromGEP: No dominating XOR found for" << I
1277-
<< "\n");
1278-
return nullptr;
1280+
// Disjointness Check 2
1281+
KnownBits KnownBaseBits(
1282+
XorWithSmallConst->getType()->getScalarSizeInBits());
1283+
computeKnownBits(XorWithSmallConst, KnownBaseBits, DL, 0, nullptr,
1284+
XorWithSmallConst, &DT);
1285+
if ((KnownBaseBits.Zero & NewConstVal) == NewConstVal) {
1286+
LLVM_DEBUG(dbgs() << DEBUG_TYPE
1287+
<< ": Transforming XOR to OR (disjoint) in function "
1288+
<< F.getName() << ":\n"
1289+
<< " Xor: " << *XorInst << "\n"
1290+
<< " Base Val: " << *XorWithSmallConst << "\n"
1291+
<< " New Const: " << NewConstVal << "\n");
1292+
1293+
auto *NewOrInst = BinaryOperator::CreateDisjointOr(
1294+
XorWithSmallConst,
1295+
ConstantInt::get(OriginalBaseVal->getType(), NewConstVal),
1296+
XorInst->getName() + ".or_disjoint", XorInst->getIterator());
1297+
1298+
NewOrInst->copyMetadata(*XorInst);
1299+
XorInst->replaceAllUsesWith(NewOrInst);
1300+
LLVM_DEBUG(dbgs() << " New Inst: " << *NewOrInst << "\n");
1301+
InstructionsToErase.push_back(XorInst); // Mark original XOR for deletion
1302+
1303+
Changed = true;
1304+
} else {
1305+
LLVM_DEBUG(
1306+
dbgs() << DEBUG_TYPE
1307+
<< ": Cannot transform XOR (not proven disjoint) in function "
1308+
<< F.getName() << ":\n"
1309+
<< " Xor: " << *XorInst << "\n"
1310+
<< " Base Val: " << *XorWithSmallConst << "\n"
1311+
<< " New Const: " << NewConstVal << "\n");
1312+
}
12791313
}
1314+
if (!InstructionsToErase.empty())
1315+
for (Instruction *I : InstructionsToErase)
1316+
I->eraseFromParent();
12801317

1281-
BinaryOperator *UserBO = cast<BinaryOperator>(UserInst);
1282-
Value *UserOp0 = UserBO->getOperand(0);
1283-
Value *UserOp1 = UserBO->getOperand(1);
1284-
ConstantInt *UserC = nullptr;
1285-
if (UserOp0 == A)
1286-
UserC = dyn_cast<ConstantInt>(UserOp1);
1287-
else if (UserOp1 == A)
1288-
UserC = dyn_cast<ConstantInt>(UserOp0);
1289-
else {
1290-
LLVM_DEBUG(dbgs() << "SeparateConstOffsetFromGEP: Found XOR" << *UserInst
1291-
<< " doesn't use value " << *A << "\n");
1292-
return nullptr;
1293-
}
1318+
return Changed;
1319+
}
12941320

1295-
if (!UserC) {
1296-
LLVM_DEBUG(
1297-
dbgs()
1298-
<< "SeparateConstOffsetFromGEP: Found XOR doesn't have constant operand"
1299-
<< *UserInst << "\n");
1300-
return nullptr;
1301-
}
1321+
// Try to transform XOR(A, B+C) in to XOR(A,C) + B where XOR(A,C) becomes
1322+
// the base for memory operations. This transformation is true under the
1323+
// following conditions
1324+
// Check 1 - B and C are disjoint.
1325+
// Check 2 - XOR(A,C) and B are disjoint.
1326+
//
1327+
// This transformation is beneficial particularly for GEPs because:
1328+
// 1. OR operations often map better to addressing modes than XOR
1329+
// 2. Disjoint OR operations preserve the semantics of the original XOR
1330+
// 3. This can enable further optimizations in the GEP offset folding pipeline
1331+
bool XorToOrDisjointTransformer::run() {
1332+
bool Changed = false;
13021333

1303-
if (!DT->dominates(UserInst, &I)) {
1304-
LLVM_DEBUG(dbgs() << "SeparateConstOffsetFromGEP: Found XOR" << *UserInst
1305-
<< " doesn't dominate " << I << "\n");
1306-
return nullptr;
1334+
// Collect all candidate XORs
1335+
for (Instruction &I : instructions(F)) {
1336+
if (auto *XorOp = dyn_cast<BinaryOperator>(&I)) {
1337+
if (XorOp->getOpcode() == Instruction::Xor) {
1338+
Value *Op0 = XorOp->getOperand(0);
1339+
ConstantInt *C1 = nullptr;
1340+
// Match: xor Op0, Constant
1341+
if (match(XorOp->getOperand(1), m_ConstantInt(C1))) {
1342+
if (hasGEPUser(XorOp)) {
1343+
XorGroups[Op0].push_back({XorOp, C1->getValue()});
1344+
}
1345+
}
1346+
}
1347+
}
13071348
}
13081349

1309-
FoundUserInst = UserInst;
1310-
C0_APInt = UserC->getValue();
1311-
1312-
// Calculate C2 = C1 - C0.
1313-
APInt C2_APInt = C1_APInt - C0_APInt;
1314-
1315-
// Check Disjointness A & C2 == 0.
1316-
KnownBits KnownA(BitWidth);
1317-
computeKnownBits(A, KnownA, *DL, 0, nullptr, &I, DT);
1350+
if (XorGroups.empty())
1351+
return false;
13181352

1319-
if ((KnownA.One & C2_APInt) != 0) {
1320-
LLVM_DEBUG(
1321-
dbgs() << "SeparateConstOffsetFromGEP: Disjointness check failed for"
1322-
<< I << "\n");
1323-
return nullptr;
1353+
// Process each group of XORs
1354+
for (auto &GroupPair : XorGroups) {
1355+
Value *OriginalBaseVal = GroupPair.first;
1356+
XorOpList &XorsInGroup = GroupPair.second;
1357+
if (processXorGroup(OriginalBaseVal, XorsInGroup))
1358+
Changed = true;
13241359
}
13251360

1326-
IRBuilder<> Builder(&I);
1327-
Builder.SetInsertPoint(&I);
1328-
Constant *C2_Const = ConstantInt::get(Ty, C2_APInt);
1329-
Twine Name = I.getName();
1330-
Value *NewOr = BinaryOperator::CreateDisjointOr(FoundUserInst, C2_Const, Name,
1331-
I.getIterator());
1332-
// Preserve metadata
1333-
if (Instruction *NewOrInst = dyn_cast<Instruction>(NewOr))
1334-
NewOrInst->copyMetadata(I);
1335-
1336-
LLVM_DEBUG(dbgs() << "SeparateConstOffsetFromGEP: Replacing" << I
1337-
<< " (used by GEP) with" << *NewOr << " based on"
1338-
<< *FoundUserInst << "\n");
1339-
1340-
return NewOr;
1361+
return Changed;
13411362
}
13421363

13431364
bool SeparateConstOffsetFromGEPLegacyPass::runOnFunction(Function &F) {
@@ -1361,7 +1382,8 @@ bool SeparateConstOffsetFromGEP::run(Function &F) {
13611382
bool Changed = false;
13621383

13631384
// Decompose xor in to "or disjoint" if possible.
1364-
Changed |= decomposeXor(F);
1385+
XorToOrDisjointTransformer XorTransformer(F, *DT, *DL);
1386+
Changed |= XorTransformer.run();
13651387

13661388
for (BasicBlock &B : F) {
13671389
if (!DT->isReachableFromEntry(&B))

0 commit comments

Comments
 (0)