-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[GVN] Support rnflow pattern matching and transform #162259
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
8826265
ea56b06
153ee85
6125286
ba3ce9f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -34,7 +34,6 @@ | |
| #include "llvm/Analysis/InstructionPrecedenceTracking.h" | ||
| #include "llvm/Analysis/InstructionSimplify.h" | ||
| #include "llvm/Analysis/Loads.h" | ||
| #include "llvm/Analysis/LoopInfo.h" | ||
| #include "llvm/Analysis/MemoryBuiltins.h" | ||
| #include "llvm/Analysis/MemoryDependenceAnalysis.h" | ||
| #include "llvm/Analysis/MemorySSA.h" | ||
|
|
@@ -2743,6 +2742,10 @@ bool GVNPass::processInstruction(Instruction *I) { | |
| } | ||
| return Changed; | ||
| } | ||
| if (SelectInst *Select = dyn_cast<SelectInst>(I)) { | ||
| if (recognizeMinFindingSelectPattern(Select)) | ||
| return true; | ||
| } | ||
|
|
||
| // Instructions with void type don't return a value, so there's | ||
| // no point in trying to find redundancies in them. | ||
|
|
@@ -3330,6 +3333,235 @@ void GVNPass::assignValNumForDeadCode() { | |
| } | ||
| } | ||
|
|
||
| // Hoist the chain of operations for the second load to preheader. | ||
| // In this transformation, we hoist the redundant load to the preheader, | ||
| // caching the first value of the iteration. This value is used to compare with | ||
| // the current value of the iteration and update the minimum value. | ||
| // The comparison is done in the loop body using the new select instruction. | ||
| // | ||
| // *** Before transformation *** | ||
| // | ||
| // preheader: | ||
| // ... | ||
| // loop: | ||
| // ... | ||
| // ... | ||
| // %val.first = load <TYPE>, ptr %ptr.first.load, align 4 | ||
| // %min.idx.ext = sext i32 %min.idx to i64 | ||
| // %ptr.<TYPE>.min = getelementptr <TYPE>, ptr %0, i64 %min.idx.ext | ||
| // %ptr.second.load = getelementptr i8, ptr %ptr.<TYPE>.min, i64 -4 | ||
| // %val.current.min = load <TYPE>, ptr %ptr.second.load, align 4 | ||
| // ... | ||
| // ... | ||
| // br i1 %cond, label %loop, label %exit | ||
| // | ||
| // *** After transformation *** | ||
| // | ||
| // preheader: | ||
| // %hoist_sext = sext i32 %initial_min_idx to i64 | ||
| // %hoist_gep1 = getelementptr <TYPE>, ptr %0, i64 %hoist_sext | ||
| // %hoist_gep2 = getelementptr i8, ptr %hoist_gep1, i64 -4 | ||
| // %hoisted_load = load <TYPE>, ptr %hoist_gep2, align 4 | ||
| // br label %loop | ||
| // | ||
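| // loop: | ||
| // (new) %known_min = phi <TYPE> [ %hoisted_load, %preheader ], | ||
| // [ %current_min, %loop ] | ||
| // %val.first = load <TYPE>, ptr %ptr.first.load, align 4 | ||
| // ... | ||
| // %cmp = <CMP_INST> <TYPE> %val.first, %known_min | ||
| // (new) %current_min = select i1 %cmp, <TYPE> %val.first, <TYPE> %known_min | ||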
| // ... | ||
| // ... | ||
| // br i1 %cond, label %loop, label %exit | ||
| // | ||
| // Parameters (all supplied by recognizeMinFindingSelectPattern): | ||
| //   L, Preheader - the loop and its preheader; the hoisted instructions are | ||
| //                  inserted right before the preheader's terminator. | ||
| //   LoadType     - type of the two compared loads. | ||
| //   BB           - block holding the pattern; must be header and latch of L. | ||
| //   LHS          - load of the current element (operand 0 of Comparison). | ||
| //   LoadVal      - load of the current minimum (operand 1 of Comparison); | ||
| //                  it is replaced by the new PHI and erased. | ||
| //   Comparison   - the compare of LHS against LoadVal. | ||
| //   Select       - the select that triggered the match (currently unused). | ||
| //   BasePtr, IndexVal, OffsetVal - pieces of LoadVal's address: | ||
| //                  gep i8 (gep LoadType BasePtr, sext(IndexVal)), OffsetVal. | ||
| // | ||
| // Returns true if the IR was rewritten. On every `return false` path the IR | ||
| // is left exactly as it was found. | ||
| // | ||
| // NOTE(review): neither this function nor its caller proves that Select is | ||
| // the index update that keeps IndexVal in sync with the memoized minimum; | ||
| // the rewrite is only sound under that assumption - confirm before landing. | ||
| bool GVNPass::transformMinFindingSelectPattern( | ||
|     Loop *L, Type *LoadType, BasicBlock *Preheader, BasicBlock *BB, Value *LHS, | ||
|     Value *LoadVal, CmpInst *Comparison, SelectInst *Select, Value *BasePtr, | ||
|     Value *IndexVal, Value *OffsetVal) { | ||
|
|
||
|   assert(IndexVal && "IndexVal is null"); | ||
|   AAResults *AA = VN.getAliasAnalysis(); | ||
|   assert(AA && "AA is null"); | ||
|
|
||
|   // The caller has already established these dynamic types, so use cast<> | ||
|   // (which asserts) instead of dereferencing an unchecked dyn_cast<>. | ||
|   auto *IndexPhi = cast<PHINode>(IndexVal); | ||
|   auto *OldLoad = cast<LoadInst>(LoadVal); | ||
|
|
||
|   // The rewrite adds a PHI to BB with exactly two incoming edges (Preheader | ||
|   // and BB itself) and reads IndexPhi's incoming value for Preheader, so BB | ||
|   // must be both header and latch of L, and both IndexPhi and the load being | ||
|   // replaced must live in BB. | ||
|   if (BB != L->getHeader() || BB != L->getLoopLatch() || | ||
|       IndexPhi->getParent() != BB || OldLoad->getParent() != BB) { | ||
|     LLVM_DEBUG( | ||
|         dbgs() << "GVN: Min finding pattern is not in a single-block loop.\n"); | ||
|     return false; | ||
|   } | ||
|
|
||
|   // LHS stays live in the new select while LoadVal is erased, so they must | ||
|   // be distinct instructions. | ||
|   if (LHS == LoadVal) { | ||
|     LLVM_DEBUG(dbgs() << "GVN: Both compared values are the same load.\n"); | ||
|     return false; | ||
|   } | ||
|
|
||
|   // The address computation is rebuilt in the preheader, so everything it | ||
|   // reads has to be available there. | ||
|   if (!L->isLoopInvariant(BasePtr) || !L->isLoopInvariant(OffsetVal)) { | ||
|     LLVM_DEBUG(dbgs() << "GVN: Address operands are not loop invariant.\n"); | ||
|     return false; | ||
|   } | ||
|
|
||
|   IRBuilder<> Builder(Preheader->getTerminator()); | ||
|   Value *InitialMinIndex = IndexPhi->getIncomingValueForBlock(Preheader); | ||
|
|
||
|   // Hoist the load and build the necessary operations. | ||
|   // 1. hoist_sext = sext InitialMinIndex to i64 | ||
|   Value *HoistedSExt = | ||
|       Builder.CreateSExt(InitialMinIndex, Builder.getInt64Ty(), "hoist_sext"); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This sext i64 seems to make assumptions about the types that are involved without ever checking them. I expect assertion failures if you use different types. Ideally this should be type independent.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done, added checks. |
||
|
|
||
|   // 2. hoist_gep1 = getelementptr LoadType, ptr BasePtr, i64 HoistedSExt | ||
|   Value *HoistedGEP1 = | ||
|       Builder.CreateGEP(LoadType, BasePtr, HoistedSExt, "hoist_gep1"); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This seems to make an assumption that the load type and the GEP type are the same, which is not necessarily true.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm..Capturing GEP in m_match is bit tricky here. How do I add such check? |
||
|
|
||
|   // 3. hoist_gep2 = getelementptr i8, ptr HoistedGEP1, i64 OffsetVal | ||
|   Value *HoistedGEP2 = Builder.CreateGEP(Builder.getInt8Ty(), HoistedGEP1, | ||
|                                          OffsetVal, "hoist_gep2"); | ||
|
|
||
|   MemoryLocation NewLoc( | ||
|       HoistedGEP2, | ||
|       LocationSize::precise( | ||
|           L->getHeader()->getDataLayout().getTypeStoreSize(LoadType))); | ||
|   // Check if any instruction in the loop clobbers this location. | ||
|   const Instruction *Clobber = nullptr; | ||
|   for (BasicBlock *LoopBB : L->blocks()) { | ||
|     for (Instruction &I : *LoopBB) { | ||
|       if (I.mayWriteToMemory() && | ||
|           isModOrRefSet(AA->getModRefInfo(&I, NewLoc))) { | ||
|         Clobber = &I; | ||
|         break; | ||
|       } | ||
|     } | ||
|     if (Clobber) | ||
|       break; | ||
|   } | ||
|   if (Clobber) { | ||
|     LLVM_DEBUG(dbgs() << "GVN: Cannot hoist - may be clobbered by: " | ||
|                       << *Clobber << "\n"); | ||
|     LLVM_DEBUG(dbgs() << "GVN: Cannot hoist - may be clobbered by some " | ||
|                          "instruction in the loop.\n"); | ||
|     // Undo the speculative address computation so that a `false` result | ||
|     // really means "IR unchanged". Values the builder folded to constants | ||
|     // are not instructions and are skipped. | ||
|     for (Value *V : {HoistedGEP2, HoistedGEP1, HoistedSExt}) | ||
|       if (auto *Inst = dyn_cast<Instruction>(V)) | ||
|         if (Inst->use_empty()) | ||
|           Inst->eraseFromParent(); | ||
|     return false; | ||
|   } | ||
|
|
||
|   // 4. hoisted_load = load LoadType, ptr HoistedGEP2 | ||
|   // Keep the alignment of the load being replaced; the default (ABI) | ||
|   // alignment could be stronger than what the original load promised. | ||
|   LoadInst *NewLoad = Builder.CreateAlignedLoad( | ||
|       LoadType, HoistedGEP2, OldLoad->getAlign(), "hoisted_load"); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In order to hoist loads, shouldn't we be querying MDA about clobbers somewhere?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. And we probably also need to update MSSA.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. Added the loop for checking clobbering. |
||
|
|
||
|   // Insert PHI node at the top of this block. | ||
|   // This PHI node will be used to memoize the current minimum value so far. | ||
|   // It is created only now, after the last bail-out, so that an early return | ||
|   // never leaves a PHI without incoming values behind. | ||
|   PHINode *KnownMinPhi = | ||
|       PHINode::Create(LoadType, 2, "known_min", BB->begin()); | ||
|
|
||
|   // Keep MemorySSA in sync while the old load is still alive: register the | ||
|   // hoisted load (letting the updater pick a defining access that is valid | ||
|   // in the preheader) and drop the access of the load about to be erased. | ||
|   if (MSSAU) { | ||
|     MemoryAccess *NewAccess = MSSAU->createMemoryAccessInBB( | ||
|         NewLoad, nullptr, Preheader, MemorySSA::BeforeTerminator); | ||
|     if (auto *NewDef = dyn_cast<MemoryDef>(NewAccess)) | ||
|       MSSAU->insertDef(NewDef, /*RenameUses=*/true); | ||
|     else | ||
|       MSSAU->insertUse(cast<MemoryUse>(NewAccess), /*RenameUses=*/true); | ||
|     MSSAU->removeMemoryAccess(OldLoad); | ||
|   } | ||
|
|
||
|   // Every user of the old in-loop load, including operand 1 of Comparison, | ||
|   // now reads the memoized minimum instead. | ||
|   OldLoad->replaceAllUsesWith(KnownMinPhi); | ||
|   OldLoad->eraseFromParent(); | ||
|
|
||
|   // Create new select instruction for selecting the minimum value. | ||
|   // The result is kept as a Value because the builder may fold the select. | ||
|   IRBuilder<> SelectBuilder(BB->getTerminator()); | ||
|   Value *CurrentMin = | ||
|       SelectBuilder.CreateSelect(Comparison, LHS, KnownMinPhi, "current_min"); | ||
|
|
||
|   // Populate the newly created PHI node | ||
|   // with (hoisted) NewLoad from the preheader and CurrentMin from the latch. | ||
|   KnownMinPhi->addIncoming(NewLoad, Preheader); | ||
|   KnownMinPhi->addIncoming(CurrentMin, BB); | ||
|
|
||
|   LLVM_DEBUG( | ||
|       dbgs() << "GVN: Transformed the code for minimum finding pattern.\n"); | ||
|   return true; | ||
| } | ||
|
|
||
| // We are looking for the following pattern: | ||
| // loop: | ||
| // ... | ||
| // ... | ||
| // %min.idx = phi i32 [ %initial_min_idx, %entry ], [ %min.idx.next, %loop ] | ||
| // ... | ||
| // %val.first = load <TYPE>, ptr %ptr.first.load, align 4 | ||
| // %min.idx.ext = sext i32 %min.idx to i64 | ||
| // %ptr.<TYPE>.min = getelementptr <TYPE>, ptr %0, i64 %min.idx.ext | ||
| // %ptr.second.load = getelementptr i8, ptr %ptr.<TYPE>.min, i64 -4 | ||
| // %val.current.min = load <TYPE>, ptr %ptr.second.load, align 4 | ||
| // %cmp = <CMP_INST> <TYPE> %val.first, %val.current.min | ||
| // ... | ||
| // %min.idx.next = select i1 %cmp, ..., i32 %min.idx | ||
| // ... | ||
| // ... | ||
| // br i1 ..., label %loop, ... | ||
| // | ||
| // Select is the candidate select instruction. Returns true only if the | ||
| // pattern was matched AND transformMinFindingSelectPattern rewrote the IR; | ||
| // every mismatch returns false without touching the IR. | ||
| bool GVNPass::recognizeMinFindingSelectPattern(SelectInst *Select) { | ||
|   Value *BasePtr = nullptr, *IndexVal = nullptr, *OffsetVal = nullptr, | ||
|         *SExt = nullptr; | ||
|   BasicBlock *BB = Select->getParent(); | ||
|
|
||
|   // If loop information is unavailable or the block is not in a loop, bail | ||
|   // out. | ||
|   // NOTE(review): assumes LI may be null when LoopInfo was not computed - | ||
|   // confirm against how the pass initializes it. | ||
|   Loop *L = LI ? LI->getLoopFor(BB) : nullptr; | ||
|   if (!L) { | ||
|     LLVM_DEBUG(dbgs() << "GVN: Could not find loop.\n"); | ||
|     return false; | ||
|   } | ||
|
|
||
|   // If preheader of the loop is not found, bail out. | ||
|   BasicBlock *Preheader = L->getLoopPreheader(); | ||
|   if (!Preheader) { | ||
|     LLVM_DEBUG(dbgs() << "GVN: Could not find loop preheader.\n"); | ||
|     return false; | ||
|   } | ||
|   CmpInst *Comparison = dyn_cast<CmpInst>(Select->getCondition()); | ||
|   if (!Comparison) { | ||
|     LLVM_DEBUG(dbgs() << "GVN: Condition is not a comparison.\n"); | ||
|     return false; | ||
|   } | ||
|
|
||
|   // Check if this is less-than comparison. | ||
|   CmpInst::Predicate Pred = Comparison->getPredicate(); | ||
|   if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT && | ||
|       Pred != CmpInst::FCMP_OLT && Pred != CmpInst::FCMP_ULT) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This means that you will fail to handle the pattern where
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure, if I got this. Do you mean I should swap, get Swapped predicate and then go ahead? |
||
|     LLVM_DEBUG(dbgs() << "GVN: Not a less-than comparison, predicate: " << Pred | ||
|                       << "\n"); | ||
|     return false; | ||
|   } | ||
|
|
||
|   // Check that both operands are loads. | ||
|   Value *LHS = Comparison->getOperand(0); | ||
|   Value *RHS = Comparison->getOperand(1); | ||
|   if (!isa<LoadInst>(LHS) || !isa<LoadInst>(RHS)) { | ||
|     LLVM_DEBUG(dbgs() << "GVN: Not both operands are loads.\n"); | ||
|     return false; | ||
|   } | ||
|
|
||
|   if (!match(RHS, m_Load(m_GEP(m_GEP(m_Value(BasePtr), m_Value(SExt)), | ||
|                                m_Value(OffsetVal))))) { | ||
|     LLVM_DEBUG(dbgs() << "GVN: Not a required load pattern.\n"); | ||
|     return false; | ||
|   } | ||
|
|
||
|   // The transform re-creates the address as | ||
|   //   gep i8 (gep LoadType BasePtr, idx), OffsetVal | ||
|   // so the matched GEPs must use exactly those source element types; | ||
|   // otherwise the hoisted load would read a different address. The casts | ||
|   // are safe because the match above succeeded. | ||
|   auto *MinLoad = cast<LoadInst>(RHS); | ||
|   auto *ByteGEP = cast<GEPOperator>(MinLoad->getPointerOperand()); | ||
|   auto *ElemGEP = cast<GEPOperator>(ByteGEP->getPointerOperand()); | ||
|   Type *LoadType = LHS->getType(); | ||
|   if (!ByteGEP->getSourceElementType()->isIntegerTy(8) || | ||
|       ElemGEP->getSourceElementType() != LoadType) { | ||
|     LLVM_DEBUG( | ||
|         dbgs() << "GVN: GEP element types do not match the load pattern.\n"); | ||
|     return false; | ||
|   } | ||
|
|
||
|   // Volatile and atomic loads must not be hoisted or removed. | ||
|   if (!MinLoad->isSimple()) { | ||
|     LLVM_DEBUG(dbgs() << "GVN: Load of the minimum is not a simple load.\n"); | ||
|     return false; | ||
|   } | ||
|
|
||
|   // Check if the SExt instruction is a sext instruction. | ||
|   SExtInst *SEInst = dyn_cast<SExtInst>(SExt); | ||
|   if (!SEInst) { | ||
|     LLVM_DEBUG(dbgs() << "GVN: not a sext instruction.\n"); | ||
|     return false; | ||
|   } | ||
|   // Check if the "To" and "from" type of the sext instruction are i64 and i32 | ||
|   // respectively. | ||
|   if (!SEInst->getType()->isIntegerTy(64) || | ||
|       !SEInst->getOperand(0)->getType()->isIntegerTy(32)) { | ||
|     LLVM_DEBUG( | ||
|         dbgs() | ||
|         << "GVN: Not matching the required type for sext instruction.\n"); | ||
|     return false; | ||
|   } | ||
|
|
||
|   IndexVal = SEInst->getOperand(0); | ||
|   // Check if the IndexVal is a PHI node. | ||
|   if (!isa<PHINode>(IndexVal)) { | ||
|     LLVM_DEBUG(dbgs() << "GVN: IndexVal is not a PHI node\n"); | ||
|     return false; | ||
|   } | ||
|
|
||
|   LLVM_DEBUG(dbgs() << "GVN: Found minimum finding pattern in Block: " | ||
|                     << Select->getParent()->getName() << ".\n"); | ||
|
|
||
|   return transformMinFindingSelectPattern(L, LoadType, Preheader, BB, LHS, RHS, | ||
|                                           Comparison, Select, BasePtr, | ||
|                                           IndexVal, OffsetVal); | ||
| } | ||
|
|
||
| class llvm::gvn::GVNLegacyPass : public FunctionPass { | ||
| public: | ||
| static char ID; // Pass identification, replacement for typeid. | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can forward declare `Loop` instead?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.