Commit 9f3a748

[InstCombine] Support GEP chains in foldCmpLoadFromIndexedGlobal()
Currently this fold only supports a single GEP. However, in ptradd representation, it may be split across multiple GEPs. In particular, PR #151333 will split off constant offset GEPs.

To support this, add a new helper decomposeLinearExpression(), which decomposes a pointer into a linear expression of the form BasePtr + Index * Scale + Offset.

I plan to also extend this helper to look through mul/shl on the index and use it in more places that currently use collectOffset() to extract a single index * scale. This will make sure such optimizations are not affected by the ptradd migration.
1 parent 305cf0e commit 9f3a748
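
For context, here is a minimal sketch of how a caller might consume the new helper. LinearExpression and decomposeLinearExpression() are the additions from this commit; the wrapper function and its name are hypothetical and only illustrate the intended usage pattern.

// Hypothetical illustration, not code from this commit: test whether Ptr is
// reachable from GV as GV + Index * Scale + Offset with Offset < Scale, the
// shape that foldCmpLoadFromIndexedGlobal() now looks for.
#include "llvm/Analysis/Loads.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/GlobalVariable.h"

using namespace llvm;

static bool isStridedAccessIntoGlobal(const DataLayout &DL, Value *Ptr,
                                      GlobalVariable *GV) {
  LinearExpression Expr = decomposeLinearExpression(DL, Ptr);
  // BasePtr is whatever remains once all fully-understood GEPs are stripped;
  // if it is not the global itself, the decomposition stopped early.
  if (!Expr.Index || Expr.BasePtr != GV)
    return false;
  // Offset and Scale share the pointer index width of Ptr, so this comparison
  // is well-defined; Offset < Scale keeps the access within a single stride.
  return Expr.Offset.ult(Expr.Scale);
}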

5 files changed: +168, -29 lines

llvm/include/llvm/Analysis/Loads.h

Lines changed: 22 additions & 0 deletions
@@ -13,7 +13,9 @@
 #ifndef LLVM_ANALYSIS_LOADS_H
 #define LLVM_ANALYSIS_LOADS_H
 
+#include "llvm/ADT/APInt.h"
 #include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/GEPNoWrapFlags.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Compiler.h"
 
@@ -193,6 +195,26 @@ LLVM_ABI bool canReplacePointersIfEqual(const Value *From, const Value *To,
                                         const DataLayout &DL);
 LLVM_ABI bool canReplacePointersInUseIfEqual(const Use &U, const Value *To,
                                              const DataLayout &DL);
+
+/// Linear expression BasePtr + Index * Scale + Offset.
+/// Index, Scale and Offset all have the same bit width, which matches the
+/// pointer index size of BasePtr.
+/// Index may be nullptr if Scale is 0.
+struct LinearExpression {
+  Value *BasePtr;
+  Value *Index = nullptr;
+  APInt Scale;
+  APInt Offset;
+  GEPNoWrapFlags Flags = GEPNoWrapFlags::all();
+
+  LinearExpression(Value *BasePtr, unsigned BitWidth)
+      : BasePtr(BasePtr), Scale(BitWidth, 0), Offset(BitWidth, 0) {}
+};
+
+/// Decompose a pointer into a linear expression. This may look through
+/// multiple GEPs.
+LLVM_ABI LinearExpression decomposeLinearExpression(const DataLayout &DL,
+                                                    Value *Ptr);
 }
 
 #endif
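
As a worked reading of these fields (illustrative only, not part of the patch; assumes a typical DataLayout where {i16, i16} occupies 4 bytes):

// For the GEP chain used in the new tests below:
//   %gep1 = getelementptr inbounds {i16, i16}, ptr @g_i16_1, i32 %idx
//   %gep2 = getelementptr inbounds i8, ptr %gep1, i32 2
// decomposeLinearExpression(DL, %gep2) would be expected to describe the
// pointer as @g_i16_1 + %idx * 4 + 2, i.e.
//   BasePtr = @g_i16_1, Index = %idx, Scale = 4, Offset = 2,
// with Flags holding, roughly, the no-wrap flags common to both GEPs
// (combined via intersectForOffsetAdd).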

llvm/lib/Analysis/Loads.cpp

Lines changed: 64 additions & 0 deletions
@@ -21,6 +21,7 @@
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/DataLayout.h"
+#include "llvm/IR/GetElementPtrTypeIterator.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Operator.h"
 
@@ -876,3 +877,66 @@ bool llvm::isReadOnlyLoop(
   }
   return true;
 }
+
+LinearExpression llvm::decomposeLinearExpression(const DataLayout &DL,
+                                                 Value *Ptr) {
+  assert(Ptr->getType()->isPointerTy() && "Must be called with pointer arg");
+
+  unsigned BitWidth = DL.getIndexTypeSizeInBits(Ptr->getType());
+  LinearExpression Expr(Ptr, BitWidth);
+
+  while (true) {
+    auto *GEP = dyn_cast<GEPOperator>(Expr.BasePtr);
+    if (!GEP || GEP->getSourceElementType()->isScalableTy())
+      return Expr;
+
+    Value *VarIndex = nullptr;
+    for (Value *Index : GEP->indices()) {
+      if (isa<ConstantInt>(Index))
+        continue;
+      // Only allow a single variable index. We do not bother to handle the
+      // case of the same variable index appearing multiple times.
+      if (Expr.Index || VarIndex)
+        return Expr;
+      VarIndex = Index;
+    }
+
+    // Don't return non-canonical indexes.
+    if (VarIndex && !VarIndex->getType()->isIntegerTy(BitWidth))
+      return Expr;
+
+    // We have verified that we can fully handle this GEP, so we can update Expr
+    // members past this point.
+    Expr.BasePtr = GEP->getPointerOperand();
+    Expr.Flags = Expr.Flags.intersectForOffsetAdd(GEP->getNoWrapFlags());
+    for (gep_type_iterator GTI = gep_type_begin(GEP), GTE = gep_type_end(GEP);
+         GTI != GTE; ++GTI) {
+      Value *Index = GTI.getOperand();
+      if (auto *ConstOffset = dyn_cast<ConstantInt>(Index)) {
+        if (ConstOffset->isZero())
+          continue;
+        if (StructType *STy = GTI.getStructTypeOrNull()) {
+          unsigned ElementIdx = ConstOffset->getZExtValue();
+          const StructLayout *SL = DL.getStructLayout(STy);
+          Expr.Offset += SL->getElementOffset(ElementIdx);
+          continue;
+        }
+        // Truncate if type size exceeds index space.
+        APInt IndexedSize(BitWidth, GTI.getSequentialElementStride(DL),
+                          /*isSigned=*/false,
+                          /*implicitTrunc=*/true);
+        Expr.Offset += ConstOffset->getValue() * IndexedSize;
+        continue;
+      }
+
+      // FIXME: Also look through a mul/shl in the index.
+      assert(Expr.Index == nullptr && "Shouldn't have index yet");
+      Expr.Index = Index;
+      // Truncate if type size exceeds index space.
+      Expr.Scale = APInt(BitWidth, GTI.getSequentialElementStride(DL),
+                         /*isSigned=*/false, /*implicitTrunc=*/true);
+    }
+  }
+
+  return Expr;
+}
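
Within a single GEP, constant struct-field indices contribute fixed layout offsets while the one variable index becomes the scaled term. A worked walk-through of what the loop accumulates (illustrative only, assuming a typical DataLayout where i32 is 4 bytes and i16 is 2 bytes):

// Hypothetical example, not from the patch:
//   %p = getelementptr inbounds {i32, [4 x i16]}, ptr %base, i32 0, i32 1, i32 %i
// index i32 0  (over the outer type, constant zero) -> skipped
// index i32 1  (field number into the struct)       -> Offset += getElementOffset(1) == 4
// index i32 %i (variable, into [4 x i16])           -> Index = %i, Scale = 2
// Result: %base + %i * 2 + 4, with Flags reflecting the GEP's no-wrap flags.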

llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp

Lines changed: 22 additions & 28 deletions
@@ -19,6 +19,7 @@
 #include "llvm/Analysis/CmpInstAnalysis.h"
 #include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Analysis/Loads.h"
 #include "llvm/Analysis/Utils/Local.h"
 #include "llvm/Analysis/VectorUtils.h"
 #include "llvm/IR/ConstantRange.h"
@@ -110,30 +111,27 @@ static bool isSignTest(ICmpInst::Predicate &Pred, const APInt &C) {
 /// If AndCst is non-null, then the loaded value is masked with that constant
 /// before doing the comparison. This handles cases like "A[i]&4 == 0".
 Instruction *InstCombinerImpl::foldCmpLoadFromIndexedGlobal(
-    LoadInst *LI, GetElementPtrInst *GEP, GlobalVariable *GV, CmpInst &ICI,
-    ConstantInt *AndCst) {
-  if (LI->isVolatile() || !GV->isConstant() || !GV->hasDefinitiveInitializer())
+    LoadInst *LI, GetElementPtrInst *GEP, CmpInst &ICI, ConstantInt *AndCst) {
+  auto *GV = dyn_cast<GlobalVariable>(getUnderlyingObject(GEP));
+  if (LI->isVolatile() || !GV || !GV->isConstant() ||
+      !GV->hasDefinitiveInitializer())
     return nullptr;
 
-  Constant *Init = GV->getInitializer();
-  TypeSize GlobalSize = DL.getTypeAllocSize(Init->getType());
   Type *EltTy = LI->getType();
   TypeSize EltSize = DL.getTypeStoreSize(EltTy);
   if (EltSize.isScalable())
     return nullptr;
 
-  unsigned IndexBW = DL.getIndexTypeSizeInBits(GEP->getType());
-  SmallMapVector<Value *, APInt, 4> VarOffsets;
-  APInt ConstOffset(IndexBW, 0);
-  if (!GEP->collectOffset(DL, IndexBW, VarOffsets, ConstOffset) ||
-      VarOffsets.size() != 1 || IndexBW > 64)
+  LinearExpression Expr = decomposeLinearExpression(DL, GEP);
+  if (!Expr.Index || Expr.BasePtr != GV || Expr.Offset.getBitWidth() > 64)
    return nullptr;
 
-  Value *Idx = VarOffsets.front().first;
-  const APInt &Stride = VarOffsets.front().second;
-  // If the index type is non-canonical, wait for it to be canonicalized.
-  if (Idx->getType()->getScalarSizeInBits() != IndexBW)
-    return nullptr;
+  Constant *Init = GV->getInitializer();
+  TypeSize GlobalSize = DL.getTypeAllocSize(Init->getType());
+
+  Value *Idx = Expr.Index;
+  const APInt &Stride = Expr.Scale;
+  const APInt &ConstOffset = Expr.Offset;
 
   // Allow an additional context offset, but only within the stride.
   if (!ConstOffset.ult(Stride))
@@ -280,7 +278,7 @@ Instruction *InstCombinerImpl::foldCmpLoadFromIndexedGlobal(
   // comparison is false if Idx was 0x80..00.
   // We need to erase the highest countTrailingZeros(ElementSize) bits of Idx.
   auto MaskIdx = [&](Value *Idx) {
-    if (!GEP->isInBounds() && Stride.countr_zero() != 0) {
+    if (!Expr.Flags.isInBounds() && Stride.countr_zero() != 0) {
       Value *Mask = Constant::getAllOnesValue(Idx->getType());
       Mask = Builder.CreateLShr(Mask, Stride.countr_zero());
       Idx = Builder.CreateAnd(Idx, Mask);
@@ -1958,10 +1956,8 @@ Instruction *InstCombinerImpl::foldICmpAndConstant(ICmpInst &Cmp,
   if (auto *C2 = dyn_cast<ConstantInt>(Y))
     if (auto *LI = dyn_cast<LoadInst>(X))
       if (auto *GEP = dyn_cast<GetElementPtrInst>(LI->getOperand(0)))
-        if (auto *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0)))
-          if (Instruction *Res =
-                  foldCmpLoadFromIndexedGlobal(LI, GEP, GV, Cmp, C2))
-            return Res;
+        if (Instruction *Res = foldCmpLoadFromIndexedGlobal(LI, GEP, Cmp, C2))
+          return Res;
 
   if (!Cmp.isEquality())
     return nullptr;
@@ -4314,10 +4310,9 @@ Instruction *InstCombinerImpl::foldICmpInstWithConstantNotInt(ICmpInst &I) {
       // Try to optimize things like "A[i] > 4" to index computations.
       if (GetElementPtrInst *GEP =
              dyn_cast<GetElementPtrInst>(LHSI->getOperand(0)))
-        if (GlobalVariable *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0)))
-          if (Instruction *Res =
-                  foldCmpLoadFromIndexedGlobal(cast<LoadInst>(LHSI), GEP, GV, I))
-            return Res;
+        if (Instruction *Res =
+                foldCmpLoadFromIndexedGlobal(cast<LoadInst>(LHSI), GEP, I))
+          return Res;
       break;
     }
 
@@ -8798,10 +8793,9 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) {
       break;
     case Instruction::Load:
      if (auto *GEP = dyn_cast<GetElementPtrInst>(LHSI->getOperand(0)))
-        if (auto *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0)))
-          if (Instruction *Res = foldCmpLoadFromIndexedGlobal(
-                  cast<LoadInst>(LHSI), GEP, GV, I))
-            return Res;
+        if (Instruction *Res =
+                foldCmpLoadFromIndexedGlobal(cast<LoadInst>(LHSI), GEP, I))
+          return Res;
      break;
    case Instruction::FPTrunc:
      if (Instruction *NV = foldFCmpFpTrunc(I, *LHSI, *RHSC))

llvm/lib/Transforms/InstCombine/InstCombineInternal.h

Lines changed: 1 addition & 1 deletion
@@ -710,7 +710,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
   bool foldAllocaCmp(AllocaInst *Alloca);
   Instruction *foldCmpLoadFromIndexedGlobal(LoadInst *LI,
                                             GetElementPtrInst *GEP,
-                                            GlobalVariable *GV, CmpInst &ICI,
+                                            CmpInst &ICI,
                                             ConstantInt *AndCst = nullptr);
   Instruction *foldFCmpIntToFPConst(FCmpInst &I, Instruction *LHSI,
                                     Constant *RHSC);

llvm/test/Transforms/InstCombine/load-cmp.ll

Lines changed: 59 additions & 0 deletions
@@ -419,6 +419,32 @@ define i1 @load_vs_array_type_mismatch_offset1(i32 %idx) {
   ret i1 %cmp
 }
 
+define i1 @load_vs_array_type_mismatch_offset1_separate_gep(i32 %idx) {
+; CHECK-LABEL: @load_vs_array_type_mismatch_offset1_separate_gep(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[IDX:%.*]], -3
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[TMP1]], 1
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %gep1 = getelementptr inbounds {i16, i16}, ptr @g_i16_1, i32 %idx
+  %gep2 = getelementptr inbounds i8, ptr %gep1, i32 2
+  %load = load i16, ptr %gep2
+  %cmp = icmp eq i16 %load, 0
+  ret i1 %cmp
+}
+
+define i1 @load_vs_array_type_mismatch_offset1_separate_gep_swapped(i32 %idx) {
+; CHECK-LABEL: @load_vs_array_type_mismatch_offset1_separate_gep_swapped(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[IDX:%.*]], -3
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[TMP1]], 1
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %gep1 = getelementptr inbounds i8, ptr @g_i16_1, i32 2
+  %gep2 = getelementptr inbounds {i16, i16}, ptr %gep1, i32 %idx
+  %load = load i16, ptr %gep2
+  %cmp = icmp eq i16 %load, 0
+  ret i1 %cmp
+}
+
 @g_i16_2 = internal constant [8 x i16] [i16 1, i16 0, i16 0, i16 1, i16 1, i16 0, i16 0, i16 1]
 
 ; idx == 0 || idx == 2
@@ -554,3 +580,36 @@ entry:
   %cond = icmp ult i32 %isOK, 5
   ret i1 %cond
 }
+
+define i1 @cmp_load_multiple_indices(i32 %idx, i32 %idx2) {
+; CHECK-LABEL: @cmp_load_multiple_indices(
+; CHECK-NEXT:    [[GEP1:%.*]] = getelementptr inbounds i16, ptr @g_i16_1, i32 [[IDX:%.*]]
+; CHECK-NEXT:    [[GEP2:%.*]] = getelementptr inbounds i16, ptr [[GEP1]], i32 [[IDX2:%.*]]
+; CHECK-NEXT:    [[GEP3:%.*]] = getelementptr inbounds nuw i8, ptr [[GEP2]], i32 2
+; CHECK-NEXT:    [[LOAD:%.*]] = load i16, ptr [[GEP3]], align 2
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i16 [[LOAD]], 0
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %gep1 = getelementptr inbounds i16, ptr @g_i16_1, i32 %idx
+  %gep2 = getelementptr inbounds i16, ptr %gep1, i32 %idx2
+  %gep3 = getelementptr inbounds i8, ptr %gep2, i32 2
+  %load = load i16, ptr %gep3
+  %cmp = icmp eq i16 %load, 0
+  ret i1 %cmp
+}
+
+define i1 @cmp_load_multiple_indices2(i32 %idx, i32 %idx2) {
+; CHECK-LABEL: @cmp_load_multiple_indices2(
+; CHECK-NEXT:    [[GEP1_SPLIT:%.*]] = getelementptr inbounds [1 x i16], ptr @g_i16_1, i32 [[IDX:%.*]]
+; CHECK-NEXT:    [[GEP1:%.*]] = getelementptr inbounds i16, ptr [[GEP1_SPLIT]], i32 [[IDX2:%.*]]
+; CHECK-NEXT:    [[GEP2:%.*]] = getelementptr inbounds nuw i8, ptr [[GEP1]], i32 2
+; CHECK-NEXT:    [[LOAD:%.*]] = load i16, ptr [[GEP2]], align 2
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i16 [[LOAD]], 0
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %gep1 = getelementptr inbounds [1 x i16], ptr @g_i16_1, i32 %idx, i32 %idx2
+  %gep2 = getelementptr inbounds i8, ptr %gep1, i32 2
+  %load = load i16, ptr %gep2
+  %cmp = icmp eq i16 %load, 0
+  ret i1 %cmp
+}
