Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions llvm/include/llvm/Transforms/Scalar/GVN.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "llvm/IR/Dominators.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Analysis/LoopInfo.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can forward declare Loop instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

#include "llvm/IR/ValueHandle.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/Compiler.h"
Expand All @@ -45,6 +46,8 @@ class FunctionPass;
class GetElementPtrInst;
class ImplicitControlFlowTracking;
class LoadInst;
class Loop;
class SelectInst;
class LoopInfo;
class MemDepResult;
class MemoryAccess;
Expand Down Expand Up @@ -405,6 +408,14 @@ class GVNPass : public PassInfoMixin<GVNPass> {
void addDeadBlock(BasicBlock *BB);
void assignValNumForDeadCode();
void assignBlockRPONumber(Function &F);

bool recognizeMinFindingSelectPattern(SelectInst *Select);
bool transformMinFindingSelectPattern(Loop *L, Type *LoadType,
BasicBlock *Preheader, BasicBlock *BB,
Value *LHS, Value *RHS,
CmpInst *Comparison, SelectInst *Select,
Value *BasePtr, Value *IndexVal,
Value *OffsetVal);
};

/// Create a legacy GVN pass.
Expand Down
234 changes: 233 additions & 1 deletion llvm/lib/Transforms/Scalar/GVN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
#include "llvm/Analysis/InstructionPrecedenceTracking.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/MemoryBuiltins.h"
#include "llvm/Analysis/MemoryDependenceAnalysis.h"
#include "llvm/Analysis/MemorySSA.h"
Expand Down Expand Up @@ -2743,6 +2742,10 @@ bool GVNPass::processInstruction(Instruction *I) {
}
return Changed;
}
if (SelectInst *Select = dyn_cast<SelectInst>(I)) {
if (recognizeMinFindingSelectPattern(Select))
return true;
}

// Instructions with void type don't return a value, so there's
// no point in trying to find redundancies in them.
Expand Down Expand Up @@ -3330,6 +3333,235 @@ void GVNPass::assignValNumForDeadCode() {
}
}

// Hoist the chain of operations for the second load to preheader.
// In this transformation, we hoist the redundant load to the preheader,
// caching the first value of the iteration. This value is used to compare with
// the current value of the iteration and update the minimum value.
// The comparison is done in the loop body using the new select instruction.
//
// *** Before transformation ***
//
// preheader:
// ...
// loop:
// ...
// ...
// %val.first = load <TYPE>, ptr %ptr.first.load, align 4
// %min.idx.ext = sext i32 %min.idx to i64
// %ptr.<TYPE>.min = getelementptr <TYPE>, ptr %0, i64 %min.idx.ext
// %ptr.second.load = getelementptr i8, ptr %ptr.<TYPE>.min, i64 -4
// %val.current.min = load <TYPE>, ptr %ptr.second.load, align 4
// ...
// ...
// br i1 %cond, label %loop, label %exit
//
// *** After transformation ***
//
// preheader:
// %min.idx.ext = sext i32 %min.idx.ext to i64
// %hoist_gep1 = getelementptr <TYPE>, ptr %0, i64 %min.idx.ext
// %hoist_gep2 = getelementptr i8, ptr %hoist_gep1, i64 -4
// %hoisted_load = load <TYPE>, ptr %hoist_gep2, align 4
// br label %loop
//
// loop:
// %val.first = load <TYPE>, ptr %ptr.first.load, align 4
// ...
// (new) %val.current.min = select i1 %cond, <TYPE> %hoisted_load, <TYPE>
// %val.current.min
// ...
// ...
// br i1 %cond, label %loop, label %exit
bool GVNPass::transformMinFindingSelectPattern(
Loop *L, Type *LoadType, BasicBlock *Preheader, BasicBlock *BB, Value *LHS,
Value *LoadVal, CmpInst *Comparison, SelectInst *Select, Value *BasePtr,
Value *IndexVal, Value *OffsetVal) {

assert(IndexVal && "IndexVal is null");
AAResults *AA = VN.getAliasAnalysis();
assert(AA && "AA is null");

IRBuilder<> Builder(Preheader->getTerminator());
Value *InitialMinIndex =
dyn_cast<PHINode>(IndexVal)->getIncomingValueForBlock(Preheader);

// Insert PHI node at the top of this block.
// This PHI node will be used to memoize the current minimum value so far.
PHINode *KnownMinPhi = PHINode::Create(LoadType, 2, "known_min", BB->begin());

// Hoist the load and build the necessary operations.
// 1. hoist_0 = sext i32 1 to i64
Value *HoistedSExt =
Builder.CreateSExt(InitialMinIndex, Builder.getInt64Ty(), "hoist_sext");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This sext i64 seems to make assumptions about the types that are involved without ever checking them. I expect assertion failures if you use different types. Ideally this should be type independent.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, added checks.


// 2. hoist_gep1 = getelementptr float, ptr BasePtr, i64 HoistedSExt
Value *HoistedGEP1 =
Builder.CreateGEP(LoadType, BasePtr, HoistedSExt, "hoist_gep1");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to make an assumption that the load type and the GEP type are the same, which is not necessarily true.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm..Capturing GEP in m_match is bit tricky here. How do I add such check?


// 3. hoist_gep2 = getelementptr i8, ptr HoistedGEP1, i64 OffsetVal
Value *HoistedGEP2 = Builder.CreateGEP(Builder.getInt8Ty(), HoistedGEP1,
OffsetVal, "hoist_gep2");

MemoryLocation NewLoc = MemoryLocation(
HoistedGEP2,
LocationSize::precise(
L->getHeader()->getDataLayout().getTypeStoreSize(LoadType)));
// Check if any instruction in the loop clobbers this location.
bool CanHoist = true;
for (BasicBlock *BB : L->blocks()) {
for (Instruction &I : *BB) {
if (I.mayWriteToMemory()) {
// Check if this instruction may clobber our hoisted load.
ModRefInfo MRI = AA->getModRefInfo(&I, NewLoc);
if (isModOrRefSet(MRI)) {
LLVM_DEBUG(dbgs() << "GVN: Cannot hoist - may be clobbered by: " << I
<< "\n");
CanHoist = false;
break;
}
}
}
if (!CanHoist)
break;
}
if (!CanHoist) {
LLVM_DEBUG(dbgs() << "GVN: Cannot hoist - may be clobbered by some "
"instruction in the loop.\n");
return false;
}

// 4. hoisted_load = load float, ptr HoistedGEP2
LoadInst *NewLoad = Builder.CreateLoad(LoadType, HoistedGEP2, "hoisted_load");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In order to hoist loads, shouldn't we be querying MDA about clobbers somewhere?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And we probably also need to update MSSA.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. Added the loop for checking clobbering.


// Let the new load now take the place of the old load.
LoadVal->replaceAllUsesWith(NewLoad);
dyn_cast<LoadInst>(LoadVal)->eraseFromParent();

// Comparison should now compare the current value and the newly inserted
// PHI node.
Comparison->setOperand(1, KnownMinPhi);

// Create new select instruction for selecting the minimum value.
IRBuilder<> SelectBuilder(BB->getTerminator());
SelectInst *CurrentMinSelect = dyn_cast<SelectInst>(
SelectBuilder.CreateSelect(Comparison, LHS, KnownMinPhi, "current_min"));

// Populate the newly created PHI node
// with (hoisted) NewLoad from the preheader and CurrentMinSelect.
KnownMinPhi->addIncoming(NewLoad, Preheader);
KnownMinPhi->addIncoming(CurrentMinSelect, BB);

if (MSSAU) {
auto *OrigUse =
MSSAU->getMemorySSA()->getMemoryAccess(dyn_cast<Instruction>(LoadVal));
if (OrigUse) {
MemoryAccess *DefiningAccess = OrigUse->getDefiningAccess();
MSSAU->createMemoryAccessInBB(NewLoad, DefiningAccess, Preheader,
MemorySSA::BeforeTerminator);
}
}
LLVM_DEBUG(
dbgs() << "GVN: Transformed the code for minimum finding pattern.\n");
return true;
}

// We are looking for the following pattern:
// loop:
// ...
// ...
// %min.idx = phi i32 [ %initial_min_idx, %entry ], [ %min.idx.next, %loop ]
// ...
// %val.first = load <TYPE>, ptr %ptr.first.load, align 4
// %min.idx.ext = sext i32 %min.idx to i64
// %ptr.<TYPE>.min = getelementptr <TYPE>, ptr %0, i64 %min.idx.ext
// %ptr.second.load = getelementptr i8, ptr %ptr.<TYPE>.min, i64 -4
// %val.current.min = load <TYPE>, ptr %ptr.second.load, align 4
// %cmp = <CMP_INST> <TYPE> %val.first, %val.current.min
// ...
// %min.idx.next = select i1 %cmp, ..., i32 %min.idx
// ...
// ...
// br i1 ..., label %loop, ...
bool GVNPass::recognizeMinFindingSelectPattern(SelectInst *Select) {
IRBuilder<> Builder(Select);
Value *BasePtr = nullptr, *IndexVal = nullptr, *OffsetVal = nullptr,
*SExt = nullptr;
BasicBlock *BB = Select->getParent();

// If the block is not in a loop, bail out.
Loop *L = LI->getLoopFor(BB);
if (!L) {
LLVM_DEBUG(dbgs() << "GVN: Could not find loop.\n");
return false;
}

// If preheader of the loop is not found, bail out.
BasicBlock *Preheader = L->getLoopPreheader();
if (!Preheader) {
LLVM_DEBUG(dbgs() << "GVN: Could not find loop preheader.\n");
return false;
}
Value *Condition = Select->getCondition();
CmpInst *Comparison = dyn_cast<CmpInst>(Condition);
if (!Comparison) {
LLVM_DEBUG(dbgs() << "GVN: Condition is not a comparison.\n");
return false;
}

// Check if this is less-than comparison.
CmpInst::Predicate Pred = Comparison->getPredicate();
if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT &&
Pred != CmpInst::FCMP_OLT && Pred != CmpInst::FCMP_ULT) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This means that you will fail to handle the pattern where x > y ? b : a is used instead of x < y ? a : b. It's better to use swapped predicate in that case.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure, if I got this. Do you mean I should swap, get Swapped predicate and then go ahead?

LLVM_DEBUG(dbgs() << "GVN: Not a less-than comparison, predicate: " << Pred
<< "\n");
return false;
}

// Check that both operands are loads.
Value *LHS = Comparison->getOperand(0);
Value *RHS = Comparison->getOperand(1);
if (!isa<LoadInst>(LHS) || !isa<LoadInst>(RHS)) {
LLVM_DEBUG(dbgs() << "GVN: Not both operands are loads.\n");
return false;
}

if (!match(RHS, m_Load(m_GEP(m_GEP(m_Value(BasePtr), m_Value(SExt)),
m_Value(OffsetVal))))) {
LLVM_DEBUG(dbgs() << "GVN: Not a required load pattern.\n");
return false;
}
// Check if the SExt instruction is a sext instruction.
SExtInst *SEInst = dyn_cast<SExtInst>(SExt);
if (!SEInst) {
LLVM_DEBUG(dbgs() << "GVN: not a sext instruction.\n");
return false;
}
// Check if the "To" and "from" type of the sext instruction are i64 and i32
// respectively.
if (SEInst->getType() != Builder.getInt64Ty() ||
SEInst->getOperand(0)->getType() != Builder.getInt32Ty()) {
LLVM_DEBUG(
dbgs()
<< "GVN: Not matching the required type for sext instruction.\n");
return false;
}

IndexVal = SEInst->getOperand(0);
// Check if the IndexVal is a PHI node.
PHINode *Phi = dyn_cast<PHINode>(IndexVal);
if (!Phi) {
LLVM_DEBUG(dbgs() << "GVN: IndexVal is not a PHI node\n");
return false;
}

LLVM_DEBUG(dbgs() << "GVN: Found minimum finding pattern in Block: "
<< Select->getParent()->getName() << ".\n");

return transformMinFindingSelectPattern(L, dyn_cast<LoadInst>(LHS)->getType(),
Preheader, BB, LHS, RHS, Comparison,
Select, BasePtr, IndexVal, OffsetVal);
}

class llvm::gvn::GVNLegacyPass : public FunctionPass {
public:
static char ID; // Pass identification, replacement for typeid.
Expand Down
Loading
Loading