Skip to content

Commit def3b3e

Browse files
committed
Push correctness fixes
1 parent e45f0aa commit def3b3e

File tree

2 files changed

+221
-42
lines changed

2 files changed

+221
-42
lines changed

llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp

Lines changed: 58 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,6 @@
160160
#include "llvm/ADT/DenseMap.h"
161161
#include "llvm/ADT/DepthFirstIterator.h"
162162
#include "llvm/ADT/SmallVector.h"
163-
#include "llvm/Analysis/AssumptionCache.h"
164163
#include "llvm/Analysis/LoopInfo.h"
165164
#include "llvm/Analysis/MemoryBuiltins.h"
166165
#include "llvm/Analysis/TargetLibraryInfo.h"
@@ -1187,12 +1186,11 @@ bool SeparateConstOffsetFromGEP::decomposeXor(Function &F) {
11871186
LLVM_DEBUG(dbgs() << "Applying " << ReplacementsToMake.size()
11881187
<< " XOR->OR Disjoint replacements in " << F.getName()
11891188
<< "\n");
1190-
for (auto &Pair : ReplacementsToMake) {
1189+
for (auto &Pair : ReplacementsToMake)
11911190
Pair.first->replaceAllUsesWith(Pair.second);
1192-
}
1193-
for (auto &Pair : ReplacementsToMake) {
1191+
1192+
for (auto &Pair : ReplacementsToMake)
11941193
Pair.first->eraseFromParent();
1195-
}
11961194
}
11971195

11981196
return FunctionChanged;
@@ -1204,9 +1202,9 @@ static llvm::Instruction *findClosestSequentialXor(Value *A, Instruction &I) {
12041202
if (auto *UserInst = llvm::dyn_cast<llvm::Instruction>(User)) {
12051203
if (UserInst->getOpcode() != Instruction::Xor || UserInst == &I)
12061204
continue;
1207-
if (!ClosestUser) {
1205+
if (!ClosestUser)
12081206
ClosestUser = UserInst;
1209-
} else {
1207+
else {
12101208
// Compare instruction positions.
12111209
if (UserInst->comesBefore(ClosestUser)) {
12121210
ClosestUser = UserInst;
@@ -1217,9 +1215,14 @@ static llvm::Instruction *findClosestSequentialXor(Value *A, Instruction &I) {
12171215
return ClosestUser;
12181216
}
12191217

1220-
/// Try to transform I = xor(A, C1) into or disjoint(Y, C2)
1218+
/// Try to transform I = xor(A, C1) into or_disjoint(Y, C2)
12211219
/// where Y = xor(A, C0) is another existing instruction dominating I,
1222-
/// C2 = C0 ^ C1, and A is known to be disjoint with C2.
1220+
/// C2 = C1 - C0, and A is known to be disjoint with C2.
1221+
///
1222+
/// This transformation is beneficial particularly for GEPs because:
1223+
/// 1. OR operations often map better to addressing modes than XOR
1224+
/// 2. Disjoint OR operations preserve the semantics of the original XOR
1225+
/// 3. This can enable further optimizations in the GEP offset folding pipeline
12231226
///
12241227
/// @param I The XOR instruction being visited.
12251228
/// @return The replacement Value* if successful, nullptr otherwise.
@@ -1237,7 +1240,7 @@ Value *SeparateConstOffsetFromGEP::tryFoldXorToOrDisjoint(Instruction &I) {
12371240
// If no user is a GEP instruction, abort the transformation.
12381241
if (!HasGepUser) {
12391242
LLVM_DEBUG(
1240-
dbgs() << "SeparateConstOffsetFromGEP: Skipping XOR->OR DISJOINT for "
1243+
dbgs() << "SeparateConstOffsetFromGEP: Skipping XOR->OR DISJOINT for"
12411244
<< I << " because it has no GEP users.\n");
12421245
return nullptr;
12431246
}
@@ -1262,11 +1265,18 @@ Value *SeparateConstOffsetFromGEP::tryFoldXorToOrDisjoint(Instruction &I) {
12621265
unsigned BitWidth = C1_APInt.getBitWidth();
12631266
Type *Ty = I.getType();
12641267

1265-
// --- Step 2: Find Dominating Y = xor A, C0 ---
1266-
Instruction *FoundUserInst = nullptr; // Instruction Y
1268+
// Find Dominating Y = xor A, C0
1269+
Instruction *FoundUserInst = nullptr;
12671270
APInt C0_APInt;
12681271

1269-
auto UserInst = findClosestSequentialXor(A, I);
1272+
// Find the closest XOR instruction using the same value.
1273+
Instruction *UserInst = findClosestSequentialXor(A, I);
1274+
if (!UserInst) {
1275+
LLVM_DEBUG(
1276+
dbgs() << "SeparateConstOffsetFromGEP: No dominating XOR found for" << I
1277+
<< "\n");
1278+
return nullptr;
1279+
}
12701280

12711281
BinaryOperator *UserBO = cast<BinaryOperator>(UserInst);
12721282
Value *UserOp0 = UserBO->getOperand(0);
@@ -1276,51 +1286,57 @@ Value *SeparateConstOffsetFromGEP::tryFoldXorToOrDisjoint(Instruction &I) {
12761286
UserC = dyn_cast<ConstantInt>(UserOp1);
12771287
else if (UserOp1 == A)
12781288
UserC = dyn_cast<ConstantInt>(UserOp0);
1279-
if (UserC) {
1280-
if (DT->dominates(UserInst, &I)) {
1281-
FoundUserInst = UserInst;
1282-
C0_APInt = UserC->getValue();
1283-
}
1289+
else {
1290+
LLVM_DEBUG(dbgs() << "SeparateConstOffsetFromGEP: Found XOR" << *UserInst
1291+
<< " doesn't use value " << *A << "\n");
1292+
return nullptr;
12841293
}
1285-
if (!FoundUserInst)
1294+
1295+
if (!UserC) {
1296+
LLVM_DEBUG(
1297+
dbgs()
1298+
<< "SeparateConstOffsetFromGEP: Found XOR doesn't have constant operand"
1299+
<< *UserInst << "\n");
12861300
return nullptr;
1301+
}
12871302

1288-
// Calculate C2.
1289-
APInt C2_APInt = C0_APInt ^ C1_APInt;
1303+
if (!DT->dominates(UserInst, &I)) {
1304+
LLVM_DEBUG(dbgs() << "SeparateConstOffsetFromGEP: Found XOR" << *UserInst
1305+
<< " doesn't dominate " << I << "\n");
1306+
return nullptr;
1307+
}
1308+
1309+
FoundUserInst = UserInst;
1310+
C0_APInt = UserC->getValue();
1311+
1312+
// Calculate C2 = C1 - C0.
1313+
APInt C2_APInt = C1_APInt - C0_APInt;
12901314

12911315
// Check Disjointness A & C2 == 0.
12921316
KnownBits KnownA(BitWidth);
1293-
AssumptionCache *AC = nullptr;
1294-
computeKnownBits(A, KnownA, *DL, 0, AC, &I, DT);
1317+
computeKnownBits(A, KnownA, *DL, 0, nullptr, &I, DT);
12951318

1296-
if ((KnownA.Zero & C2_APInt) != C2_APInt)
1319+
if ((KnownA.One & C2_APInt) != 0) {
1320+
LLVM_DEBUG(
1321+
dbgs() << "SeparateConstOffsetFromGEP: Disjointness check failed for"
1322+
<< I << "\n");
12971323
return nullptr;
1324+
}
12981325

12991326
IRBuilder<> Builder(&I);
1300-
Builder.SetInsertPoint(&I); // Access Builder directly
1327+
Builder.SetInsertPoint(&I);
13011328
Constant *C2_Const = ConstantInt::get(Ty, C2_APInt);
1302-
Twine Name = I.getName(); // Create Twine explicitly
1329+
Twine Name = I.getName();
13031330
Value *NewOr = BinaryOperator::CreateDisjointOr(FoundUserInst, C2_Const, Name,
13041331
I.getIterator());
1305-
// Transformation Conditions Met.
1306-
LLVM_DEBUG(dbgs() << "SeparateConstOffsetFromGEP: Replacing " << I
1307-
<< " (used by GEP) with " << *NewOr << " based on "
1308-
<< *FoundUserInst << "\n");
1309-
1310-
#if 0
13111332
// Preserve metadata
1312-
if (Instruction *NewOrInst = dyn_cast<Instruction>(NewOr)) {
1333+
if (Instruction *NewOrInst = dyn_cast<Instruction>(NewOr))
13131334
NewOrInst->copyMetadata(I);
1314-
} else {
1315-
assert(false && "CreateNUWOr did not return an Instruction");
1316-
if (NewOr)
1317-
NewOr->deleteValue();
1318-
return nullptr;
1319-
}
1320-
#endif
13211335

1322-
// Return the replacement value. runOnFunction will handle replacement &
1323-
// deletion.
1336+
LLVM_DEBUG(dbgs() << "SeparateConstOffsetFromGEP: Replacing" << I
1337+
<< " (used by GEP) with" << *NewOr << " based on"
1338+
<< *FoundUserInst << "\n");
1339+
13241340
return NewOr;
13251341
}
13261342

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
2+
; RUN: opt -mtriple=amdgcn-amd-amdhsa -passes=separate-const-offset-from-gep \
3+
; RUN: -S < %s | FileCheck %s
4+
5+
6+
; Test with GEP user and known bits: Ensure the transformation occurs when the xor has a GEP user
7+
define ptr @test_with_gep_user(ptr %ptr) {
8+
; CHECK-LABEL: define ptr @test_with_gep_user(
9+
; CHECK-SAME: ptr [[PTR:%.*]]) {
10+
; CHECK-NEXT: [[ENTRY:.*:]]
11+
; CHECK-NEXT: [[BASE:%.*]] = add i64 0, 0
12+
; CHECK-NEXT: [[XOR1:%.*]] = xor i64 [[BASE]], 8
13+
; CHECK-NEXT: [[XOR21:%.*]] = or disjoint i64 [[XOR1]], 16
14+
; CHECK-NEXT: [[GEP:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[XOR21]]
15+
; CHECK-NEXT: ret ptr [[GEP]]
16+
;
17+
entry:
18+
%base = add i64 0,0
19+
%xor1 = xor i64 %base, 8
20+
%xor2 = xor i64 %base, 24 ; Should be replaced with OR of %xor1 and 16
21+
%gep = getelementptr i8, ptr %ptr, i64 %xor2
22+
ret ptr %gep
23+
}
24+
25+
26+
; Test with non-GEP user: Ensure the transformation does not occur
27+
define i32 @test_with_non_gep_user(ptr %ptr) {
28+
; CHECK-LABEL: define i32 @test_with_non_gep_user(
29+
; CHECK-SAME: ptr [[PTR:%.*]]) {
30+
; CHECK-NEXT: [[ENTRY:.*:]]
31+
; CHECK-NEXT: [[BASE:%.*]] = add i32 0, 0
32+
; CHECK-NEXT: [[XOR1:%.*]] = xor i32 [[BASE]], 8
33+
; CHECK-NEXT: [[XOR2:%.*]] = xor i32 [[BASE]], 24
34+
; CHECK-NEXT: [[ADD:%.*]] = add i32 [[XOR2]], 5
35+
; CHECK-NEXT: ret i32 [[ADD]]
36+
;
37+
entry:
38+
%base = add i32 0,0
39+
%xor1 = xor i32 %base, 8
40+
%xor2 = xor i32 %base, 24
41+
%add = add i32 %xor2, 5
42+
ret i32 %add
43+
}
44+
45+
; Test with non-constant operand: Ensure the transformation does not occur
46+
define ptr @test_with_non_constant_operand(i64 %val, i64 %val2, ptr %ptr) {
47+
; CHECK-LABEL: define ptr @test_with_non_constant_operand(
48+
; CHECK-SAME: i64 [[VAL:%.*]], i64 [[VAL2:%.*]], ptr [[PTR:%.*]]) {
49+
; CHECK-NEXT: [[ENTRY:.*:]]
50+
; CHECK-NEXT: [[XOR1:%.*]] = xor i64 [[VAL]], [[VAL2]]
51+
; CHECK-NEXT: [[XOR2:%.*]] = xor i64 [[VAL]], 24
52+
; CHECK-NEXT: [[GEP:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[XOR2]]
53+
; CHECK-NEXT: ret ptr [[GEP]]
54+
;
55+
entry:
56+
%xor1 = xor i64 %val, %val2 ; Non-constant operand
57+
%xor2 = xor i64 %val, 24
58+
%gep = getelementptr i8, ptr %ptr, i64 %xor2
59+
ret ptr %gep
60+
}
61+
62+
; Test with unknown disjoint bits: Ensure the transformation does not occur
63+
define ptr @test_with_unknown_disjoint_bits(i64 %base, ptr %ptr) {
64+
; CHECK-LABEL: define ptr @test_with_unknown_disjoint_bits(
65+
; CHECK-SAME: i64 [[BASE:%.*]], ptr [[PTR:%.*]]) {
66+
; CHECK-NEXT: [[ENTRY:.*:]]
67+
; CHECK-NEXT: [[XOR1:%.*]] = xor i64 [[BASE]], 8
68+
; CHECK-NEXT: [[XOR21:%.*]] = or disjoint i64 [[XOR1]], 16
69+
; CHECK-NEXT: [[GEP:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[XOR21]]
70+
; CHECK-NEXT: ret ptr [[GEP]]
71+
;
72+
entry:
73+
%xor1 = xor i64 %base, 8
74+
%xor2 = xor i64 %base, 24
75+
%gep = getelementptr i8, ptr %ptr, i64 %xor2
76+
ret ptr %gep
77+
}
78+
79+
; Test with non-disjoint bits: Ensure the transformation does not occur
80+
define ptr @test_with_non_disjoint_bits(i64 %val, ptr %ptr) {
81+
; CHECK-LABEL: define ptr @test_with_non_disjoint_bits(
82+
; CHECK-SAME: i64 [[VAL:%.*]], ptr [[PTR:%.*]]) {
83+
; CHECK-NEXT: [[ENTRY:.*:]]
84+
; CHECK-NEXT: [[AND:%.*]] = and i64 [[VAL]], 31
85+
; CHECK-NEXT: [[XOR1:%.*]] = xor i64 [[AND]], 4
86+
; CHECK-NEXT: [[XOR21:%.*]] = or disjoint i64 [[XOR1]], 16
87+
; CHECK-NEXT: [[GEP:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[XOR21]]
88+
; CHECK-NEXT: ret ptr [[GEP]]
89+
;
90+
entry:
91+
%and = and i64 %val, 31 ; val can have bits 0-4 set
92+
%xor1 = xor i64 %and, 4 ; Flips bit 2
93+
%xor2 = xor i64 %and, 20 ; Flips bits 2 and 4, should NOT replace since bit 4 overlaps with possible val bits
94+
%gep = getelementptr i8, ptr %ptr, i64 %xor2
95+
ret ptr %gep
96+
}
97+
98+
; Test with multiple xor operations in sequence
99+
define ptr @test_multiple_xors(ptr %ptr) {
100+
; CHECK-LABEL: define ptr @test_multiple_xors(
101+
; CHECK-SAME: ptr [[PTR:%.*]]) {
102+
; CHECK-NEXT: [[ENTRY:.*:]]
103+
; CHECK-NEXT: [[BASE:%.*]] = add i64 2, 0
104+
; CHECK-NEXT: [[XOR1:%.*]] = xor i64 [[BASE]], 8
105+
; CHECK-NEXT: [[XOR21:%.*]] = or disjoint i64 [[XOR1]], 16
106+
; CHECK-NEXT: [[XOR32:%.*]] = or disjoint i64 [[XOR1]], 24
107+
; CHECK-NEXT: [[XOR43:%.*]] = or disjoint i64 [[XOR1]], 64
108+
; CHECK-NEXT: [[GEP2:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[XOR21]]
109+
; CHECK-NEXT: [[GEP3:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[XOR32]]
110+
; CHECK-NEXT: [[GEP4:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[XOR43]]
111+
; CHECK-NEXT: ret ptr [[GEP4]]
112+
;
113+
entry:
114+
%base = add i64 2,0
115+
%xor1 = xor i64 %base, 8
116+
%xor2 = xor i64 %base, 24 ; Should be replaced with OR
117+
%xor3 = xor i64 %base, 32
118+
%xor4 = xor i64 %base, 72 ; Should be replaced with OR
119+
%gep2 = getelementptr i8, ptr %ptr, i64 %xor2
120+
%gep3 = getelementptr i8, ptr %ptr, i64 %xor3
121+
%gep4 = getelementptr i8, ptr %ptr, i64 %xor4
122+
ret ptr %gep4
123+
}
124+
125+
126+
; Test with operand order variations
127+
define ptr @test_operand_order(ptr %ptr) {
128+
; CHECK-LABEL: define ptr @test_operand_order(
129+
; CHECK-SAME: ptr [[PTR:%.*]]) {
130+
; CHECK-NEXT: [[ENTRY:.*:]]
131+
; CHECK-NEXT: [[BASE:%.*]] = add i64 2, 0
132+
; CHECK-NEXT: [[XOR1:%.*]] = xor i64 [[BASE]], 12
133+
; CHECK-NEXT: [[XOR21:%.*]] = or disjoint i64 [[XOR1]], 12
134+
; CHECK-NEXT: [[GEP:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[XOR21]]
135+
; CHECK-NEXT: ret ptr [[GEP]]
136+
;
137+
entry:
138+
%base = add i64 2,0
139+
%xor1 = xor i64 %base, 12
140+
%xor2 = xor i64 24, %base ; Operands reversed, should still be replaced
141+
%gep = getelementptr i8, ptr %ptr, i64 %xor2
142+
ret ptr %gep
143+
}
144+
145+
146+
; Test with multiple xor operations in sequence
147+
define ptr @aatest_multiple_xors(ptr %ptr) {
148+
; CHECK-LABEL: define ptr @aatest_multiple_xors(
149+
; CHECK-SAME: ptr [[PTR:%.*]]) {
150+
; CHECK-NEXT: [[ENTRY:.*:]]
151+
; CHECK-NEXT: [[BASE:%.*]] = add i64 2, 0
152+
; CHECK-NEXT: [[XOR1:%.*]] = xor i64 [[BASE]], 72
153+
; CHECK-NEXT: [[XOR21:%.*]] = or disjoint i64 [[XOR1]], -48
154+
; CHECK-NEXT: [[GEP:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[XOR21]]
155+
; CHECK-NEXT: ret ptr [[GEP]]
156+
;
157+
entry:
158+
%base = add i64 2,0
159+
%xor1 = xor i64 %base, 72
160+
%xor2 = xor i64 %base, 24 ; Should be replaced with OR
161+
%gep = getelementptr i8, ptr %ptr, i64 %xor2
162+
ret ptr %gep
163+
}

0 commit comments

Comments
 (0)