Skip to content

Commit d78d773

Browse files
committed
[GVN] Support rnflow pattern matching and transform
1 parent c9b4169 commit d78d773

File tree

3 files changed

+185
-0
lines changed

3 files changed

+185
-0
lines changed

llvm/include/llvm/Transforms/Scalar/GVN.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "llvm/IR/Dominators.h"
2323
#include "llvm/IR/InstrTypes.h"
2424
#include "llvm/IR/PassManager.h"
25+
#include "llvm/Analysis/LoopInfo.h"
2526
#include "llvm/IR/ValueHandle.h"
2627
#include "llvm/Support/Allocator.h"
2728
#include "llvm/Support/Compiler.h"
@@ -45,6 +46,7 @@ class FunctionPass;
4546
class GetElementPtrInst;
4647
class ImplicitControlFlowTracking;
4748
class LoadInst;
49+
class SelectInst;
4850
class LoopInfo;
4951
class MemDepResult;
5052
class MemoryAccess;
@@ -405,6 +407,8 @@ class GVNPass : public PassInfoMixin<GVNPass> {
405407
void addDeadBlock(BasicBlock *BB);
406408
void assignValNumForDeadCode();
407409
void assignBlockRPONumber(Function &F);
410+
411+
bool optimizeMinMaxFindingSelectPattern(SelectInst *Select);
408412
};
409413

410414
/// Create a legacy GVN pass.

llvm/lib/Transforms/Scalar/GVN.cpp

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2743,6 +2743,10 @@ bool GVNPass::processInstruction(Instruction *I) {
27432743
}
27442744
return Changed;
27452745
}
2746+
if (SelectInst *Select = dyn_cast<SelectInst>(I)) {
2747+
if (optimizeMinMaxFindingSelectPattern(Select))
2748+
return true;
2749+
}
27462750

27472751
// Instructions with void type don't return a value, so there's
27482752
// no point in trying to find redundancies in them.
@@ -3330,6 +3334,124 @@ void GVNPass::assignValNumForDeadCode() {
33303334
}
33313335
}
33323336

3337+
/// Try to replace the loop-carried "current minimum" load of a min-index
/// search loop with a value carried in a PHI node.
///
/// The shape handled is a single-block loop of the form
///
///   loop:
///     %idx  = phi iN [ %init, %preheader ], [ %sel, %loop ]
///     %cand = load T, gep OuterTy (gep ElemTy %base, %i), C1
///     %min  = load T, gep OuterTy (gep ElemTy %base, sext(%idx)), C2
///     %cmp  = cmp-less-than %cand, %min
///     %sel  = select %cmp, trunc nsw (%i + K), %idx        ; \p Select
///
/// When the address selected by the new index is provably the address %cand
/// was loaded from, and nothing in the loop has side effects, %min always
/// equals either the value loaded before the loop or the last %cand that won
/// the comparison. The load of %min is then computed once in the preheader
/// ("hoisted_load") and carried through a new PHI ("known_min") that is
/// updated by a new select ("current_min").
///
/// \param Select the candidate index-update select.
/// \returns true if the IR was changed, false otherwise. No IR is modified
/// unless every legality check has passed.
bool GVNPass::optimizeMinMaxFindingSelectPattern(SelectInst *Select) {
  auto Reject = [](const char *Why) {
    (void)Why;
    LLVM_DEBUG(dbgs() << "GVN: min-finding pattern rejected: " << Why << "\n");
    return false;
  };

  LLVM_DEBUG(dbgs() << "GVN: Analyzing select instruction for minimum finding "
                       "pattern: "
                    << *Select << "\n");

  auto *Comparison = dyn_cast<CmpInst>(Select->getCondition());
  if (!Comparison)
    return Reject("condition is not a comparison");

  // Only strict less-than comparisons (signed/unsigned integer, ordered or
  // unordered floating point) are treated as a minimum search.
  switch (Comparison->getPredicate()) {
  case CmpInst::ICMP_SLT:
  case CmpInst::ICMP_ULT:
  case CmpInst::FCMP_OLT:
  case CmpInst::FCMP_ULT:
    break;
  default:
    return Reject("not a less-than comparison");
  }

  // Both compared values must be plain (non-volatile, non-atomic) loads that
  // are re-executed on every iteration of the select's block.
  Value *LHS = Comparison->getOperand(0);
  Value *RHS = Comparison->getOperand(1);
  auto *CandLoad = dyn_cast<LoadInst>(LHS);
  auto *MinLoad = dyn_cast<LoadInst>(RHS);
  if (!CandLoad || !MinLoad)
    return Reject("not both operands are loads");
  if (!CandLoad->isSimple() || !MinLoad->isSimple())
    return Reject("loads are not simple");

  BasicBlock *BB = Select->getParent();
  if (CandLoad->getParent() != BB || MinLoad->getParent() != BB)
    return Reject("loads are not in the select's block");

  // LI is held by pointer; never dereference it when no loop info is set.
  if (!LI)
    return Reject("no loop info available");
  Loop *L = LI->getLoopFor(BB);
  if (!L)
    return Reject("could not find loop");
  BasicBlock *Preheader = L->getLoopPreheader();
  if (!Preheader)
    return Reject("could not find loop preheader");
  // The new PHI below gets exactly one preheader value and one backedge
  // value, which is only valid when BB is the whole loop.
  if (L->getNumBlocks() != 1)
    return Reject("loop is not a single block");

  // Current minimum: load(gep (gep Base, sext(Idx)), Off).
  //   %90 = sext i32 %.05.i to i64
  //   %91 = getelementptr float, ptr %0, i64 %90
  //   %92 = getelementptr i8, ptr %91, i64 -4
  //   %93 = load float, ptr %92, align 4
  Value *BasePtr = nullptr, *IndexVal = nullptr, *OffsetVal = nullptr;
  if (!match(RHS,
             m_Load(m_GEP(m_GEP(m_Value(BasePtr), m_SExt(m_Value(IndexVal))),
                          m_Value(OffsetVal)))))
    return Reject("second load does not match the indexed-load shape");

  auto *MinOuterGEP = dyn_cast<GetElementPtrInst>(MinLoad->getPointerOperand());
  auto *MinInnerGEP =
      MinOuterGEP
          ? dyn_cast<GetElementPtrInst>(MinOuterGEP->getPointerOperand())
          : nullptr;
  if (!MinInnerGEP)
    return Reject("address computation is not made of GEP instructions");

  // The index must be the loop-carried PHI that this very select updates and
  // that keeps its old value whenever the comparison is false.
  auto *Phi = dyn_cast<PHINode>(IndexVal);
  if (!Phi || Phi->getParent() != BB || Phi->getNumIncomingValues() != 2)
    return Reject("index is not a two-input PHI of the loop block");
  if (Select->getFalseValue() != Phi ||
      Phi->getIncomingValueForBlock(BB) != Select)
    return Reject("select is not the update of the index PHI");

  // Candidate: load(gep (gep Base, CandIndex), CandOff) with the same base
  // and the same GEP element types as the current-minimum load.
  Value *CandIndex = nullptr;
  const APInt *CandOff = nullptr, *MinOff = nullptr;
  if (!match(OffsetVal, m_APInt(MinOff)) ||
      !match(LHS, m_Load(m_GEP(m_GEP(m_Specific(BasePtr), m_Value(CandIndex)),
                               m_APInt(CandOff)))))
    return Reject("first load does not match the indexed-load shape");

  auto *CandOuterGEP =
      dyn_cast<GetElementPtrInst>(CandLoad->getPointerOperand());
  auto *CandInnerGEP =
      CandOuterGEP
          ? dyn_cast<GetElementPtrInst>(CandOuterGEP->getPointerOperand())
          : nullptr;
  if (!CandInnerGEP ||
      CandInnerGEP->getSourceElementType() !=
          MinInnerGEP->getSourceElementType() ||
      CandOuterGEP->getSourceElementType() !=
          MinOuterGEP->getSourceElementType())
    return Reject("loads use different address computations");

  // Do the address arithmetic at the pointer index width so that it wraps
  // exactly like the GEPs do.
  const DataLayout &DL = Select->getModule()->getDataLayout();
  Type *WideIdxTy = MinInnerGEP->getOperand(1)->getType();
  if (!WideIdxTy->isIntegerTy() || CandIndex->getType() != WideIdxTy)
    return Reject("index types differ");
  unsigned IdxBits = WideIdxTy->getIntegerBitWidth();
  if (IdxBits != DL.getIndexTypeSizeInBits(BasePtr->getType()))
    return Reject("index type is not the pointer index type");

  TypeSize InnerSize = DL.getTypeAllocSize(MinInnerGEP->getSourceElementType());
  TypeSize OuterSize = DL.getTypeAllocSize(MinOuterGEP->getSourceElementType());
  if (InnerSize.isScalable() || OuterSize.isScalable())
    return Reject("scalable element type");

  // The index stored on a successful comparison must be `trunc nsw (X)` with
  // X == CandIndex + Step, so that sext(new index) == X.
  auto *NewIdxTrunc = dyn_cast<TruncInst>(Select->getTrueValue());
  if (!NewIdxTrunc || !NewIdxTrunc->hasNoSignedWrap() ||
      NewIdxTrunc->getOperand(0)->getType() != WideIdxTy)
    return Reject("new index is not a trunc nsw of a wide index");
  Value *WideNewIdx = NewIdxTrunc->getOperand(0);
  APInt Step(IdxBits, 0);
  if (WideNewIdx != CandIndex) {
    const APInt *StepC = nullptr;
    if (!match(WideNewIdx, m_c_Add(m_Specific(CandIndex), m_APInt(StepC))))
      return Reject("new index is not the candidate index plus a constant");
    Step = *StepC;
  }

  // Prove that the element addressed by the new index is the element the
  // candidate was loaded from:
  //   Step * sizeof(Elem) + MinOff * sizeof(Outer) == CandOff * sizeof(Outer)
  APInt InnerBytes = APInt(64, InnerSize.getFixedValue()).zextOrTrunc(IdxBits);
  APInt OuterBytes = APInt(64, OuterSize.getFixedValue()).zextOrTrunc(IdxBits);
  APInt NewMinByteOff =
      Step * InnerBytes + MinOff->sextOrTrunc(IdxBits) * OuterBytes;
  APInt CandByteOff = CandOff->sextOrTrunc(IdxBits) * OuterBytes;
  if (NewMinByteOff != CandByteOff)
    return Reject("new index does not address the candidate element");

  // The hoisted address must be computable in the preheader, and the loaded
  // memory must not change (nor may execution leave the block early) while
  // the value is carried in a register instead of being reloaded.
  if (!L->isLoopInvariant(BasePtr))
    return Reject("base pointer is not loop invariant");
  if (any_of(*BB,
             [](const Instruction &I) { return I.mayHaveSideEffects(); }))
    return Reject("loop contains instructions with side effects");

  LLVM_DEBUG(dbgs() << "GVN: Found minimum finding pattern in Block: "
                    << BB->getName() << "\n");

  // Everything is legal: start modifying the IR.
  Value *InitialMinIndex = Phi->getIncomingValueForBlock(Preheader);

  // Rebuild the address chain for the initial index in the preheader, using
  // the types and alignment of the load being replaced.
  IRBuilder<> Builder(Preheader->getTerminator());
  Value *HoistedSExt =
      Builder.CreateSExt(InitialMinIndex, WideIdxTy, "hoist_sext");
  Value *HoistedGEP1 =
      Builder.CreateGEP(MinInnerGEP->getSourceElementType(), BasePtr,
                        HoistedSExt, "hoist_gep1");
  Value *HoistedGEP2 =
      Builder.CreateGEP(MinOuterGEP->getSourceElementType(), HoistedGEP1,
                        OffsetVal, "hoist_gep2");
  LoadInst *NewLoad = Builder.CreateAlignedLoad(
      MinLoad->getType(), HoistedGEP2, MinLoad->getAlign(), "hoisted_load");

  // Carry the current minimum through a PHI at the top of the loop block.
  PHINode *KnownMinPhi =
      PHINode::Create(MinLoad->getType(), 2, "known_min", BB->begin());

  // Every user of the in-loop load (including the comparison) now reads the
  // carried value. The load itself is left in place, unused, instead of
  // being erased here: it was visited earlier in this block, so GVN's
  // internal tables may still refer to it.
  MinLoad->replaceAllUsesWith(KnownMinPhi);

  // Select the value that accompanies the index chosen by \p Select.
  IRBuilder<> SelectBuilder(BB->getTerminator());
  Value *CurrentMin = SelectBuilder.CreateSelect(Comparison, CandLoad,
                                                 KnownMinPhi, "current_min");

  KnownMinPhi->addIncoming(NewLoad, Preheader);
  KnownMinPhi->addIncoming(CurrentMin, BB);

  LLVM_DEBUG(dbgs() << "Transformed the code\n");
  return true;
}
3453+
3454+
33333455
class llvm::gvn::GVNLegacyPass : public FunctionPass {
33343456
public:
33353457
static char ID; // Pass identification, replacement for typeid.
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 6
2+
; Minimal test case containing only the .lr.ph.i basic block
3+
; RUN: opt -passes=gvn -S < %s | FileCheck %s
4+
5+
; Single-block loop in which %.05.i carries a selected index: each iteration
; compares the float addressed through %indvars.iv.i with the float addressed
; through %.05.i, and %.1.i takes the new index when the former is smaller.
; The CHECK lines expect the second load to be computed once in the entry
; block (HOISTED_LOAD) and replaced inside the loop by the KNOWN_MIN phi,
; which is updated by the CURRENT_MIN select.
define void @test_lr_ph_i(ptr %0) {
6+
; CHECK-LABEL: define void @test_lr_ph_i(
7+
; CHECK-SAME: ptr [[TMP0:%.*]]) {
8+
; CHECK-NEXT: [[ENTRY:.*]]:
9+
; CHECK-NEXT: [[HOIST_GEP1:%.*]] = getelementptr float, ptr [[TMP0]], i64 1
10+
; CHECK-NEXT: [[HOIST_GEP2:%.*]] = getelementptr i8, ptr [[HOIST_GEP1]], i64 -4
11+
; CHECK-NEXT: [[HOISTED_LOAD:%.*]] = load float, ptr [[HOIST_GEP2]], align 4
12+
; CHECK-NEXT: br label %[[DOTLR_PH_I:.*]]
13+
; CHECK: [[_LR_PH_I:.*:]]
14+
; CHECK-NEXT: [[KNOWN_MIN:%.*]] = phi float [ [[HOISTED_LOAD]], %[[ENTRY]] ], [ [[CURRENT_MIN:%.*]], %[[DOTLR_PH_I]] ]
15+
; CHECK-NEXT: [[INDVARS_IV_I:%.*]] = phi i64 [ 1, %[[ENTRY]] ], [ [[INDVARS_IV_NEXT_I:%.*]], %[[DOTLR_PH_I]] ]
16+
; CHECK-NEXT: [[TMP1:%.*]] = phi i64 [ 0, %[[ENTRY]] ], [ [[TMP10:%.*]], %[[DOTLR_PH_I]] ]
17+
; CHECK-NEXT: [[DOT05_I:%.*]] = phi i32 [ 1, %[[ENTRY]] ], [ [[DOT1_I:%.*]], %[[DOTLR_PH_I]] ]
18+
; CHECK-NEXT: [[INDVARS_IV_NEXT_I]] = add nsw i64 [[INDVARS_IV_I]], -1
19+
; CHECK-NEXT: [[TMP2:%.*]] = getelementptr float, ptr [[TMP0]], i64 [[INDVARS_IV_I]]
20+
; CHECK-NEXT: [[TMP3:%.*]] = getelementptr i8, ptr [[TMP2]], i64 -8
21+
; CHECK-NEXT: [[TMP4:%.*]] = load float, ptr [[TMP3]], align 4
22+
; CHECK-NEXT: [[TMP5:%.*]] = sext i32 [[DOT05_I]] to i64
23+
; CHECK-NEXT: [[TMP6:%.*]] = getelementptr float, ptr [[TMP0]], i64 [[TMP5]]
24+
; CHECK-NEXT: [[TMP7:%.*]] = getelementptr i8, ptr [[TMP6]], i64 -4
25+
; CHECK-NEXT: [[TMP8:%.*]] = fcmp contract olt float [[TMP4]], [[KNOWN_MIN]]
26+
; CHECK-NEXT: [[TMP9:%.*]] = trunc nsw i64 [[INDVARS_IV_NEXT_I]] to i32
27+
; CHECK-NEXT: [[DOT1_I]] = select i1 [[TMP8]], i32 [[TMP9]], i32 [[DOT05_I]]
28+
; CHECK-NEXT: [[TMP10]] = add nsw i64 [[TMP1]], -1
29+
; CHECK-NEXT: [[TMP11:%.*]] = icmp samesign ugt i64 [[TMP1]], 1
30+
; CHECK-NEXT: [[CURRENT_MIN]] = select i1 [[TMP8]], float [[TMP4]], float [[KNOWN_MIN]]
31+
; CHECK-NEXT: br i1 [[TMP11]], label %[[DOTLR_PH_I]], label %[[EXIT:.*]]
32+
; CHECK: [[EXIT]]:
33+
; CHECK-NEXT: ret void
34+
;
35+
entry:
36+
br label %.lr.ph.i
37+
38+
.lr.ph.i: ; preds = %.lr.ph.i, %entry
39+
%indvars.iv.i = phi i64 [ 1, %entry ], [ %indvars.iv.next.i, %.lr.ph.i ]
40+
%86 = phi i64 [ 0, %entry ], [ %96, %.lr.ph.i ]
41+
; Selected index: starts at 1 and is updated by the select %.1.i below.
%.05.i = phi i32 [ 1, %entry ], [ %.1.i, %.lr.ph.i ]
42+
%indvars.iv.next.i = add nsw i64 %indvars.iv.i, -1
43+
%87 = getelementptr float, ptr %0, i64 %indvars.iv.i
44+
%88 = getelementptr i8, ptr %87, i64 -8 ; first load : %0 + 4 * 1 - 8
45+
%89 = load float, ptr %88, align 4
46+
%90 = sext i32 %.05.i to i64
47+
%91 = getelementptr float, ptr %0, i64 %90 ; %0 + 4 * 1
48+
%92 = getelementptr i8, ptr %91, i64 -4 ; second load : %0 + 4 * 1 - 4
49+
%93 = load float, ptr %92, align 4
50+
; Compare the element at the moving index with the element at the selected index.
%94 = fcmp contract olt float %89, %93
51+
%95 = trunc nsw i64 %indvars.iv.next.i to i32
52+
; Take the new index when the comparison holds, otherwise keep the old one.
%.1.i = select i1 %94, i32 %95, i32 %.05.i
53+
; Loop counter: %86 is decremented each iteration and the loop repeats while it is greater than 1.
%96 = add nsw i64 %86, -1
54+
%97 = icmp samesign ugt i64 %86, 1
55+
br i1 %97, label %.lr.ph.i, label %exit
56+
57+
exit:
58+
ret void
59+
}

0 commit comments

Comments
 (0)