Skip to content

Commit 6ce1b5f

Browse files
committed
!fixup introduce SafeUDivMode to SCEVExpander.
1 parent bb5f8e9 commit 6ce1b5f

File tree

9 files changed

+144
-70
lines changed

9 files changed

+144
-70
lines changed

llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,11 @@ class SCEVExpander : public SCEVVisitor<SCEVExpander, Value *> {
124124
/// "expanded" form.
125125
bool LSRMode;
126126

127+
/// When true, rewrite any divisors of UDiv expressions that may be 0 to
128+
/// umax(Divisor, 1) to avoid introducing UB. If the divisor may be poison,
129+
/// freeze it first.
130+
bool SafeUDivMode = false;
131+
127132
typedef IRBuilder<InstSimplifyFolder, IRBuilderCallbackInserter> BuilderType;
128133
BuilderType Builder;
129134

@@ -300,6 +305,9 @@ class SCEVExpander : public SCEVVisitor<SCEVExpander, Value *> {
300305
/// location and their operands are defined at this location.
301306
bool isSafeToExpandAt(const SCEV *S, const Instruction *InsertionPoint) const;
302307

308+
static bool isSafeToExpand(const SCEV *S, bool CanonicalMode,
309+
ScalarEvolution &SE);
310+
303311
/// Insert code to directly compute the specified SCEV expression into the
304312
/// program. The code is inserted into the specified block.
305313
Value *expandCodeFor(const SCEV *SH, Type *Ty, BasicBlock::iterator I);
@@ -418,13 +426,11 @@ class SCEVExpander : public SCEVVisitor<SCEVExpander, Value *> {
418426
BasicBlock::iterator findInsertPointAfter(Instruction *I,
419427
Instruction *MustDominate) const;
420428

421-
/// If \p L contains exits which may execute conditionally and contain UDiv
422-
/// expressions with divisors that can be 0, expanding \p BTC may introduce
423-
/// new UB. In that case, rewrite UDiv(A, B) to UDiv(A, UMAX(1, B)). If B is
424-
/// 0, that exit cannot be taken.
425429
static const SCEV *rewriteExpressionToRemoveUB(const SCEV *BTC, Loop *L,
426430
ScalarEvolution &SE);
427431

432+
void setSafeUDivMode() { SafeUDivMode = true; }
433+
428434
private:
429435
LLVMContext &getContext() const { return SE.getContext(); }
430436

llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp

Lines changed: 9 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -665,6 +665,12 @@ Value *SCEVExpander::visitUDivExpr(const SCEVUDivExpr *S) {
665665
}
666666

667667
Value *RHS = expand(S->getRHS());
668+
if (SafeUDivMode && !SE.isKnownNonZero(S->getRHS())) {
669+
if (!isGuaranteedNotToBePoison(RHS))
670+
RHS = Builder.CreateFreeze(RHS);
671+
RHS = Builder.CreateIntrinsic(RHS->getType(), Intrinsic::umax,
672+
{RHS, ConstantInt::get(RHS->getType(), 1)});
673+
}
668674
return InsertBinop(Instruction::UDiv, LHS, RHS, SCEV::FlagAnyWrap,
669675
/*IsSafeToHoist*/ SE.isKnownNonZero(S->getRHS()));
670676
}
@@ -2341,15 +2347,15 @@ struct SCEVFindUnsafe {
23412347
};
23422348
} // namespace
23432349

2344-
static bool isSafeToExpand(const SCEV *S, bool CanonicalMode,
2345-
ScalarEvolution &SE) {
2350+
bool SCEVExpander::isSafeToExpand(const SCEV *S, bool CanonicalMode,
2351+
ScalarEvolution &SE) {
23462352
SCEVFindUnsafe Search(SE, CanonicalMode);
23472353
visitAll(S, Search);
23482354
return !Search.IsUnsafe;
23492355
}
23502356

23512357
bool SCEVExpander::isSafeToExpand(const SCEV *S) const {
2352-
return ::isSafeToExpand(S, CanonicalMode, SE);
2358+
return isSafeToExpand(S, CanonicalMode, SE);
23532359
}
23542360

23552361
bool SCEVExpander::isSafeToExpandAt(const SCEV *S,
@@ -2374,47 +2380,6 @@ bool SCEVExpander::isSafeToExpandAt(const SCEV *S,
23742380
return false;
23752381
}
23762382

2377-
namespace {
2378-
struct SCEVUDivRewriter : public SCEVRewriteVisitor<SCEVUDivRewriter> {
2379-
SCEVUDivRewriter(ScalarEvolution &SE) : SCEVRewriteVisitor(SE) {}
2380-
2381-
static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE) {
2382-
SCEVUDivRewriter Rewriter(SE);
2383-
return Rewriter.visit(Scev);
2384-
}
2385-
2386-
const SCEV *visitUDivExpr(const SCEVUDivExpr *Expr) {
2387-
auto *LHS = visit(Expr->getLHS());
2388-
auto *RHS = visit(Expr->getRHS());
2389-
return SE.getUDivExpr(LHS, SE.getUMaxExpr(SE.getOne(RHS->getType()), RHS));
2390-
}
2391-
};
2392-
} // namespace
2393-
2394-
const SCEV *SCEVExpander::rewriteExpressionToRemoveUB(const SCEV *Expr, Loop *L,
2395-
ScalarEvolution &SE) {
2396-
SmallVector<BasicBlock *> Exiting;
2397-
L->getExitingBlocks(Exiting);
2398-
// Check if exit count for any exit that may execute unconditionally may in
2399-
// introduce UB. Note that we can skip checks in the header or if there's a
2400-
// single exit, as in those cases we know that the exit count will be
2401-
// evaluated in each loop iteration. There are other cases where the exiting
2402-
// block executes on each loop iteration, but we don't have a cheap way to
2403-
// check at the moment.
2404-
2405-
if (Exiting.size() == 1 || all_of(Exiting, [L, &SE](BasicBlock *E) {
2406-
if (L->getHeader() == E)
2407-
return true;
2408-
const SCEV *EC = SE.getExitCount(L, E);
2409-
if (isa<SCEVCouldNotCompute>(EC))
2410-
return true;
2411-
return ::isSafeToExpand(EC, true, SE);
2412-
}))
2413-
return Expr;
2414-
2415-
return SCEVUDivRewriter::rewrite(Expr, SE);
2416-
}
2417-
24182383
void SCEVExpanderCleaner::cleanup() {
24192384
// Result is used, nothing to remove.
24202385
if (ResultUsed)

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -916,9 +916,6 @@ const SCEV *createTripCountSCEV(Type *IdxTy, PredicatedScalarEvolution &PSE,
916916
assert(!isa<SCEVCouldNotCompute>(BackedgeTakenCount) && "Invalid loop count");
917917

918918
ScalarEvolution &SE = *PSE.getSE();
919-
if (OrigLoop)
920-
BackedgeTakenCount = SCEVExpander::rewriteExpressionToRemoveUB(
921-
BackedgeTakenCount, OrigLoop, SE);
922919
return SE.getTripCountFromExitCount(BackedgeTakenCount, IdxTy, OrigLoop);
923920
}
924921

llvm/lib/Transforms/Vectorize/VPlan.cpp

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -866,8 +866,31 @@ VPlanPtr VPlan::createInitialVPlan(const SCEV *TripCount, ScalarEvolution &SE,
866866
VPIRBasicBlock *Entry = new VPIRBasicBlock(TheLoop->getLoopPreheader());
867867
VPBasicBlock *VecPreheader = new VPBasicBlock("vector.ph");
868868
auto Plan = std::make_unique<VPlan>(Entry, VecPreheader);
869-
Plan->TripCount =
870-
vputils::getOrCreateVPValueForSCEVExpr(*Plan, TripCount, SE);
869+
870+
bool NeedsSafeUDivMode = false;
871+
{
872+
SmallVector<BasicBlock *> Exiting;
873+
TheLoop->getExitingBlocks(Exiting);
874+
875+
// Check if exit count for any exit that may execute unconditionally may in
876+
// introduce UB. Note that we can skip checks in the header or if there's a
877+
// single exit, as in those cases we know that the exit count will be
878+
// evaluated in each loop iteration. There are other cases where the exiting
879+
// block executes on each loop iteration, but we don't have a cheap way to
880+
// check at the moment.
881+
NeedsSafeUDivMode =
882+
Exiting.size() != 1 && any_of(Exiting, [TheLoop, &SE](BasicBlock *E) {
883+
if (TheLoop->getHeader() == E)
884+
return false;
885+
const SCEV *EC = SE.getExitCount(TheLoop, E);
886+
if (isa<SCEVCouldNotCompute>(EC))
887+
return false;
888+
return !SCEVExpander::isSafeToExpand(EC, true, SE);
889+
});
890+
}
891+
892+
Plan->TripCount = vputils::getOrCreateVPValueForSCEVExpr(*Plan, TripCount, SE,
893+
NeedsSafeUDivMode);
871894
// Create VPRegionBlock, with empty header and latch blocks, to be filled
872895
// during processing later.
873896
VPBasicBlock *HeaderVPBB = new VPBasicBlock("vector.body");

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2771,9 +2771,15 @@ class VPExpandSCEVRecipe : public VPSingleDefRecipe {
27712771
const SCEV *Expr;
27722772
ScalarEvolution &SE;
27732773

2774+
/// When set to true, set SafeUDivMode when expanding the SCEV to avoid
2775+
/// introducing UB.
2776+
bool SafeUDivMode;
2777+
27742778
public:
2775-
VPExpandSCEVRecipe(const SCEV *Expr, ScalarEvolution &SE)
2776-
: VPSingleDefRecipe(VPDef::VPExpandSCEVSC, {}), Expr(Expr), SE(SE) {}
2779+
VPExpandSCEVRecipe(const SCEV *Expr, ScalarEvolution &SE,
2780+
bool SafeUDivMode = false)
2781+
: VPSingleDefRecipe(VPDef::VPExpandSCEVSC, {}), Expr(Expr), SE(SE),
2782+
SafeUDivMode(SafeUDivMode) {}
27772783

27782784
~VPExpandSCEVRecipe() override = default;
27792785

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2932,6 +2932,8 @@ void VPExpandSCEVRecipe::execute(VPTransformState &State) {
29322932
assert(!State.Instance && "cannot be used in per-lane");
29332933
const DataLayout &DL = State.CFG.PrevBB->getDataLayout();
29342934
SCEVExpander Exp(SE, DL, "induction");
2935+
if (SafeUDivMode)
2936+
Exp.setSafeUDivMode();
29352937

29362938
Value *Res = Exp.expandCodeFor(Expr, Expr->getType(),
29372939
&*State.Builder.GetInsertPoint());

llvm/lib/Transforms/Vectorize/VPlanUtils.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ bool vputils::onlyFirstPartUsed(const VPValue *Def) {
2323
}
2424

2525
VPValue *vputils::getOrCreateVPValueForSCEVExpr(VPlan &Plan, const SCEV *Expr,
26-
ScalarEvolution &SE) {
26+
ScalarEvolution &SE,
27+
bool SafeUDivMode) {
2728
if (auto *Expanded = Plan.getSCEVExpansion(Expr))
2829
return Expanded;
2930
VPValue *Expanded = nullptr;
@@ -32,7 +33,7 @@ VPValue *vputils::getOrCreateVPValueForSCEVExpr(VPlan &Plan, const SCEV *Expr,
3233
else if (auto *E = dyn_cast<SCEVUnknown>(Expr))
3334
Expanded = Plan.getOrAddLiveIn(E->getValue());
3435
else {
35-
Expanded = new VPExpandSCEVRecipe(Expr, SE);
36+
Expanded = new VPExpandSCEVRecipe(Expr, SE, SafeUDivMode);
3637
Plan.getPreheader()->appendRecipe(Expanded->getDefiningRecipe());
3738
}
3839
Plan.addSCEVExpansion(Expr, Expanded);

llvm/lib/Transforms/Vectorize/VPlanUtils.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ bool onlyFirstPartUsed(const VPValue *Def);
2424
/// pre-header already contains a recipe expanding \p Expr, return it. If not,
2525
/// create a new one.
2626
VPValue *getOrCreateVPValueForSCEVExpr(VPlan &Plan, const SCEV *Expr,
27-
ScalarEvolution &SE);
27+
ScalarEvolution &SE,
28+
bool SaveUDivMode = false);
2829

2930
/// Returns true if \p VPV is uniform after vectorization.
3031
inline bool isUniformAfterVectorization(const VPValue *VPV) {

0 commit comments

Comments
 (0)