Skip to content

Commit 6f391b1

Browse files
committed
Added trailing zeros counting pattern recognition.
1 parent c7eede5 commit 6f391b1

File tree

2 files changed

+369
-1
lines changed

2 files changed

+369
-1
lines changed

llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp

Lines changed: 222 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,10 @@ class LoopIdiomRecognize {
243243
bool recognizeShiftUntilBitTest();
244244
bool recognizeShiftUntilZero();
245245

246+
bool recognizeAndInsertCtz();
247+
void transformLoopToCtz(BasicBlock *PreCondBB, Instruction *CntInst,
248+
PHINode *CntPhi, Value *Var);
249+
246250
/// @}
247251
};
248252
} // end anonymous namespace
@@ -1484,7 +1488,8 @@ bool LoopIdiomRecognize::runOnNoncountableLoop() {
14841488
<< CurLoop->getHeader()->getName() << "\n");
14851489

14861490
return recognizePopcount() || recognizeAndInsertFFS() ||
1487-
recognizeShiftUntilBitTest() || recognizeShiftUntilZero();
1491+
recognizeShiftUntilBitTest() || recognizeShiftUntilZero() ||
1492+
recognizeAndInsertCtz();
14881493
}
14891494

14901495
/// Check if the given conditional branch is based on the comparison between
@@ -2868,3 +2873,219 @@ bool LoopIdiomRecognize::recognizeShiftUntilZero() {
28682873
++NumShiftUntilZero;
28692874
return MadeChange;
28702875
}
2876+
2877+
// This function recognizes a loop that counts the number of trailing zeros
2878+
// loop:
2879+
// %count.010 = phi i32 [ %add, %while.body ], [ 0, %while.body.preheader ]
2880+
// %n.addr.09 = phi i32 [ %shr, %while.body ], [ %n, %while.body.preheader ]
2881+
// %add = add nuw nsw i32 %count.010, 1
2882+
// %shr = ashr exact i32 %n.addr.09, 1
2883+
// %0 = and i32 %n.addr.09, 2
2884+
// %cmp1 = icmp eq i32 %0, 0
2885+
// br i1 %cmp1, label %while.body, label %if.end.loopexit
2886+
static bool detectShiftUntilZeroAndOneIdiom(Loop *CurLoop, Value *&InitX,
2887+
Instruction *&CntInst,
2888+
PHINode *&CntPhi) {
2889+
BasicBlock *LoopEntry;
2890+
Value *VarX;
2891+
Instruction *DefX;
2892+
2893+
CntInst = nullptr;
2894+
CntPhi = nullptr;
2895+
LoopEntry = *(CurLoop->block_begin());
2896+
2897+
// Check if the loop-back branch is in desirable form.
2898+
// "if (x == 0) goto loop-entry"
2899+
if (Value *T = matchCondition(
2900+
dyn_cast<BranchInst>(LoopEntry->getTerminator()), LoopEntry, true)) {
2901+
DefX = dyn_cast<Instruction>(T);
2902+
} else {
2903+
LLVM_DEBUG(dbgs() << "Bad condition for branch instruction\n");
2904+
return false;
2905+
}
2906+
2907+
// operand compares with 2, because we are looking for "x & 2"
2908+
// which was optimized by previous passes from "(x >> 1) & 1"
2909+
2910+
if (!match(DefX, m_c_And(PatternMatch::m_Value(VarX),
2911+
PatternMatch::m_SpecificInt(2))))
2912+
return false;
2913+
2914+
// check if VarX is a phi node
2915+
2916+
auto *PhiX = dyn_cast<PHINode>(VarX);
2917+
2918+
if (!PhiX || PhiX->getParent() != LoopEntry)
2919+
return false;
2920+
2921+
Instruction *DefXRShift = nullptr;
2922+
2923+
// check if PhiX has a shift instruction as a operand, which is a "x >> 1"
2924+
2925+
for (int i = 0; i < 2; ++i) {
2926+
if (auto *Inst = dyn_cast<Instruction>(PhiX->getOperand(i))) {
2927+
if (Inst->getOpcode() == Instruction::AShr ||
2928+
Inst->getOpcode() == Instruction::LShr) {
2929+
DefXRShift = Inst;
2930+
break;
2931+
}
2932+
}
2933+
}
2934+
2935+
if (DefXRShift == nullptr)
2936+
return false;
2937+
2938+
// check if the shift instruction is a "x >> 1"
2939+
auto *Shft = dyn_cast<ConstantInt>(DefXRShift->getOperand(1));
2940+
if (!Shft || !Shft->isOne())
2941+
return false;
2942+
2943+
if (DefXRShift->getOperand(0) != VarX)
2944+
return false;
2945+
2946+
InitX = PhiX->getIncomingValueForBlock(CurLoop->getLoopPreheader());
2947+
2948+
// Find the instruction which counts the trailing zeros: cnt.next = cnt + 1.
2949+
for (Instruction &Inst : llvm::make_range(
2950+
LoopEntry->getFirstNonPHI()->getIterator(), LoopEntry->end())) {
2951+
if (Inst.getOpcode() != Instruction::Add)
2952+
continue;
2953+
2954+
ConstantInt *Inc = dyn_cast<ConstantInt>(Inst.getOperand(1));
2955+
if (!Inc || !Inc->isOne())
2956+
continue;
2957+
2958+
PHINode *Phi = getRecurrenceVar(Inst.getOperand(0), &Inst, LoopEntry);
2959+
if (!Phi)
2960+
continue;
2961+
2962+
CntInst = &Inst;
2963+
CntPhi = Phi;
2964+
break;
2965+
}
2966+
if (!CntInst)
2967+
return false;
2968+
2969+
return true;
2970+
}
2971+
2972+
/// Recognize CTTZ idiom in a non-countable loop and convert it to countable
2973+
/// with CTTZ of variable as a trip count. If CTTZ was inserted, returns true;
2974+
/// otherwise, returns false.
2975+
///
2976+
// int count_trailing_zeroes(uint32_t n) {
2977+
// int count = 0;
2978+
// if (n == 0){
2979+
// return 32;
2980+
// }
2981+
// while ((n & 1) == 0) {
2982+
// count += 1;
2983+
// n >>= 1;
2984+
// }
2985+
//
2986+
//
2987+
// return count;
2988+
// }
2989+
bool LoopIdiomRecognize::recognizeAndInsertCtz() {
2990+
// Give up if the loop has multiple blocks or multiple backedges.
2991+
if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 1)
2992+
return false;
2993+
2994+
Value *InitX;
2995+
PHINode *CntPhi = nullptr;
2996+
Instruction *CntInst = nullptr;
2997+
// For counting trailing zeros with uncountable loop idiom, transformation is
2998+
// always profitable if IdiomCanonicalSize is 7.
2999+
const size_t IdiomCanonicalSize = 7;
3000+
3001+
if (!detectShiftUntilZeroAndOneIdiom(CurLoop, InitX, CntInst, CntPhi))
3002+
return false;
3003+
3004+
BasicBlock *PH = CurLoop->getLoopPreheader();
3005+
3006+
auto *PreCondBB = PH->getSinglePredecessor();
3007+
if (!PreCondBB)
3008+
return false;
3009+
auto *PreCondBI = dyn_cast<BranchInst>(PreCondBB->getTerminator());
3010+
if (!PreCondBI)
3011+
return false;
3012+
3013+
// check that initial value is not zero and "(init & 1) == 0"
3014+
// initial value must not be zero, because it will cause infinite loop
3015+
// without this check, after replacing the loop with cttz, the counter will be
3016+
// size of int, while before the replacement the loop would have executed
3017+
// indefinitely
3018+
3019+
// match that case, where n is initial value
3020+
// entry:
3021+
// %cmp.not = icmp eq i32 %n, 0
3022+
// br i1 %cmp.not, label %cleanup, label %while.cond.preheader
3023+
//
3024+
// while.cond.preheader:
3025+
// %and5 = and i32 %n, 1
3026+
// %cmp16 = icmp eq i32 %and5, 0
3027+
// br i1 %cmp16, label %while.body.preheader, label %cleanup
3028+
3029+
Value *PreCond = matchCondition(PreCondBI, PH, true);
3030+
3031+
if (!PreCond)
3032+
return false;
3033+
3034+
Value *InitPredX = nullptr;
3035+
if (!match(PreCond, m_c_And(PatternMatch::m_Value(InitPredX),
3036+
PatternMatch::m_One())) ||
3037+
InitPredX != InitX)
3038+
return false;
3039+
auto *PrePreCondBB = PreCondBB->getSinglePredecessor();
3040+
if (!PrePreCondBB)
3041+
return false;
3042+
auto *PrePreCondBI = dyn_cast<BranchInst>(PrePreCondBB->getTerminator());
3043+
if (!PrePreCondBI)
3044+
return false;
3045+
if (matchCondition(PrePreCondBI, PreCondBB) != InitX)
3046+
return false;
3047+
3048+
// CTTZ intrinsic always profitable after deleting the loop.
3049+
// the loop has only 7 instructions:
3050+
3051+
// @llvm.dbg doesn't count as they have no semantic effect.
3052+
auto InstWithoutDebugIt = CurLoop->getHeader()->instructionsWithoutDebug();
3053+
uint32_t HeaderSize =
3054+
std::distance(InstWithoutDebugIt.begin(), InstWithoutDebugIt.end());
3055+
if (HeaderSize != IdiomCanonicalSize)
3056+
return false;
3057+
3058+
transformLoopToCtz(PH, CntInst, CntPhi, InitX);
3059+
return true;
3060+
}
3061+
3062+
void LoopIdiomRecognize::transformLoopToCtz(BasicBlock *Preheader,
3063+
Instruction *CntInst,
3064+
PHINode *CntPhi, Value *InitX) {
3065+
BranchInst *PreheaderBr = cast<BranchInst>(Preheader->getTerminator());
3066+
const DebugLoc &DL = CntInst->getDebugLoc();
3067+
3068+
// Insert the CTTZ instruction at the end of the preheader block
3069+
IRBuilder<> Builder(PreheaderBr);
3070+
Builder.SetCurrentDebugLocation(DL);
3071+
Value *Count = createFFSIntrinsic(Builder, InitX, DL,
3072+
/* is zero poison */ true, Intrinsic::cttz);
3073+
3074+
Value *NewCount = Count;
3075+
3076+
NewCount = Builder.CreateZExtOrTrunc(NewCount, CntInst->getType());
3077+
3078+
Value *CntInitVal = CntPhi->getIncomingValueForBlock(Preheader);
3079+
// If the counter was being incremented in the loop, add NewCount to the
3080+
// counter's initial value, but only if the initial value is not zero.
3081+
ConstantInt *InitConst = dyn_cast<ConstantInt>(CntInitVal);
3082+
if (!InitConst || !InitConst->isZero())
3083+
NewCount = Builder.CreateAdd(NewCount, CntInitVal);
3084+
3085+
BasicBlock *Body = *(CurLoop->block_begin());
3086+
3087+
// All the references to the original counter outside
3088+
// the loop are replaced with the NewCount
3089+
CntInst->replaceUsesOutsideBlock(NewCount, Body);
3090+
SE->forgetLoop(CurLoop);
3091+
}
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
2+
; RUN: opt -passes=loop-idiom -mtriple=riscv32 -S < %s | FileCheck %s
3+
; RUN: opt -passes=loop-idiom -mtriple=riscv64 -S < %s | FileCheck %s
4+
5+
; Copied from popcnt test.
6+
7+
;To recognize this pattern:
8+
;int ctz(uint32_t n)
9+
;{
10+
; int count = 0;
11+
; if (n == 0)
12+
; {
13+
; return 32;
14+
; }
15+
; while ((n & 1) == 0)
16+
; {
17+
; count += 1;
18+
; n >>= 1;
19+
; }
20+
; return count;
21+
;}
22+
23+
define signext i32 @count_trailing_zeroes(i32 noundef signext %n) local_unnamed_addr #0 {
24+
; CHECK-LABEL: @count_trailing_zeroes(
25+
; CHECK-NEXT: entry:
26+
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[N:%.*]], 0
27+
; CHECK-NEXT: br i1 [[CMP]], label [[CLEANUP:%.*]], label [[WHILE_COND_PREHEADER:%.*]]
28+
; CHECK: while.cond.preheader:
29+
; CHECK-NEXT: [[AND4:%.*]] = and i32 [[N]], 1
30+
; CHECK-NEXT: [[CMP15:%.*]] = icmp eq i32 [[AND4]], 0
31+
; CHECK-NEXT: br i1 [[CMP15]], label [[WHILE_BODY_PREHEADER:%.*]], label [[CLEANUP]]
32+
; CHECK: while.body.preheader:
33+
; CHECK-NEXT: [[TMP0:%.*]] = call i32 @llvm.cttz.i32(i32 [[N]], i1 true)
34+
; CHECK-NEXT: br label [[WHILE_BODY:%.*]]
35+
; CHECK: while.body:
36+
; CHECK-NEXT: [[COUNT_07:%.*]] = phi i32 [ [[ADD:%.*]], [[WHILE_BODY]] ], [ 0, [[WHILE_BODY_PREHEADER]] ]
37+
; CHECK-NEXT: [[N_ADDR_06:%.*]] = phi i32 [ [[SHR:%.*]], [[WHILE_BODY]] ], [ [[N]], [[WHILE_BODY_PREHEADER]] ]
38+
; CHECK-NEXT: [[ADD]] = add nuw nsw i32 [[COUNT_07]], 1
39+
; CHECK-NEXT: [[SHR]] = lshr i32 [[N_ADDR_06]], 1
40+
; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[N_ADDR_06]], 2
41+
; CHECK-NEXT: [[CMP1:%.*]] = icmp eq i32 [[TMP1]], 0
42+
; CHECK-NEXT: br i1 [[CMP1]], label [[WHILE_BODY]], label [[CLEANUP_LOOPEXIT:%.*]]
43+
; CHECK: cleanup.loopexit:
44+
; CHECK-NEXT: [[ADD_LCSSA:%.*]] = phi i32 [ [[TMP0]], [[WHILE_BODY]] ]
45+
; CHECK-NEXT: br label [[CLEANUP]]
46+
; CHECK: cleanup:
47+
; CHECK-NEXT: [[RETVAL_0:%.*]] = phi i32 [ 32, [[ENTRY:%.*]] ], [ 0, [[WHILE_COND_PREHEADER]] ], [ [[ADD_LCSSA]], [[CLEANUP_LOOPEXIT]] ]
48+
; CHECK-NEXT: ret i32 [[RETVAL_0]]
49+
;
50+
entry:
51+
%cmp = icmp eq i32 %n, 0
52+
br i1 %cmp, label %cleanup, label %while.cond.preheader
53+
54+
while.cond.preheader: ; preds = %entry
55+
%and4 = and i32 %n, 1
56+
%cmp15 = icmp eq i32 %and4, 0
57+
br i1 %cmp15, label %while.body, label %cleanup
58+
59+
while.body: ; preds = %while.cond.preheader, %while.body
60+
%count.07 = phi i32 [ %add, %while.body ], [ 0, %while.cond.preheader ]
61+
%n.addr.06 = phi i32 [ %shr, %while.body ], [ %n, %while.cond.preheader ]
62+
%add = add nuw nsw i32 %count.07, 1
63+
%shr = lshr i32 %n.addr.06, 1
64+
%0 = and i32 %n.addr.06, 2
65+
%cmp1 = icmp eq i32 %0, 0
66+
br i1 %cmp1, label %while.body, label %cleanup
67+
68+
cleanup: ; preds = %while.body, %while.cond.preheader, %entry
69+
%retval.0 = phi i32 [ 32, %entry ], [ 0, %while.cond.preheader ], [ %add, %while.body ]
70+
ret i32 %retval.0
71+
}
72+
73+
;int ctz(uint64_t n)
74+
;{
75+
; int count = 0;
76+
; if (n != 0)
77+
; {
78+
; while ((n & 1) == 0)
79+
; {
80+
; n >>= 1;
81+
; count += 1;
82+
; }
83+
; }
84+
; else
85+
; {
86+
; return 64;
87+
; }
88+
; return count;
89+
;}
90+
91+
define dso_local signext i32 @ctz(i64 noundef %n) local_unnamed_addr {
92+
; CHECK-LABEL: @ctz(
93+
; CHECK-NEXT: entry:
94+
; CHECK-NEXT: [[CMP_NOT:%.*]] = icmp eq i64 [[N:%.*]], 0
95+
; CHECK-NEXT: br i1 [[CMP_NOT]], label [[CLEANUP:%.*]], label [[WHILE_COND_PREHEADER:%.*]]
96+
; CHECK: while.cond.preheader:
97+
; CHECK-NEXT: [[AND5:%.*]] = and i64 [[N]], 1
98+
; CHECK-NEXT: [[CMP16:%.*]] = icmp eq i64 [[AND5]], 0
99+
; CHECK-NEXT: br i1 [[CMP16]], label [[WHILE_BODY_PREHEADER:%.*]], label [[CLEANUP]]
100+
; CHECK: while.body.preheader:
101+
; CHECK-NEXT: [[TMP0:%.*]] = call i64 @llvm.cttz.i64(i64 [[N]], i1 true)
102+
; CHECK-NEXT: [[TMP1:%.*]] = trunc i64 [[TMP0]] to i32
103+
; CHECK-NEXT: br label [[WHILE_BODY:%.*]]
104+
; CHECK: while.body:
105+
; CHECK-NEXT: [[COUNT_08:%.*]] = phi i32 [ [[ADD:%.*]], [[WHILE_BODY]] ], [ 0, [[WHILE_BODY_PREHEADER]] ]
106+
; CHECK-NEXT: [[N_ADDR_07:%.*]] = phi i64 [ [[SHR:%.*]], [[WHILE_BODY]] ], [ [[N]], [[WHILE_BODY_PREHEADER]] ]
107+
; CHECK-NEXT: [[SHR]] = lshr i64 [[N_ADDR_07]], 1
108+
; CHECK-NEXT: [[ADD]] = add nuw nsw i32 [[COUNT_08]], 1
109+
; CHECK-NEXT: [[TMP2:%.*]] = and i64 [[N_ADDR_07]], 2
110+
; CHECK-NEXT: [[CMP1:%.*]] = icmp eq i64 [[TMP2]], 0
111+
; CHECK-NEXT: br i1 [[CMP1]], label [[WHILE_BODY]], label [[CLEANUP_LOOPEXIT:%.*]]
112+
; CHECK: cleanup.loopexit:
113+
; CHECK-NEXT: [[ADD_LCSSA:%.*]] = phi i32 [ [[TMP1]], [[WHILE_BODY]] ]
114+
; CHECK-NEXT: br label [[CLEANUP]]
115+
; CHECK: cleanup:
116+
; CHECK-NEXT: [[RETVAL_0:%.*]] = phi i32 [ 64, [[ENTRY:%.*]] ], [ 0, [[WHILE_COND_PREHEADER]] ], [ [[ADD_LCSSA]], [[CLEANUP_LOOPEXIT]] ]
117+
; CHECK-NEXT: ret i32 [[RETVAL_0]]
118+
;
119+
entry:
120+
%cmp.not = icmp eq i64 %n, 0
121+
br i1 %cmp.not, label %cleanup, label %while.cond.preheader
122+
123+
while.cond.preheader: ; preds = %entry
124+
%and5 = and i64 %n, 1
125+
%cmp16 = icmp eq i64 %and5, 0
126+
br i1 %cmp16, label %while.body.preheader, label %cleanup
127+
128+
while.body.preheader: ; preds = %while.cond.preheader
129+
br label %while.body
130+
131+
while.body: ; preds = %while.body.preheader, %while.body
132+
%count.08 = phi i32 [ %add, %while.body ], [ 0, %while.body.preheader ]
133+
%n.addr.07 = phi i64 [ %shr, %while.body ], [ %n, %while.body.preheader ]
134+
%shr = lshr i64 %n.addr.07, 1
135+
%add = add nuw nsw i32 %count.08, 1
136+
%0 = and i64 %n.addr.07, 2
137+
%cmp1 = icmp eq i64 %0, 0
138+
br i1 %cmp1, label %while.body, label %cleanup.loopexit
139+
140+
cleanup.loopexit: ; preds = %while.body
141+
%add.lcssa = phi i32 [ %add, %while.body ]
142+
br label %cleanup
143+
144+
cleanup: ; preds = %cleanup.loopexit, %while.cond.preheader, %entry
145+
%retval.0 = phi i32 [ 64, %entry ], [ 0, %while.cond.preheader ], [ %add.lcssa, %cleanup.loopexit ]
146+
ret i32 %retval.0
147+
}

0 commit comments

Comments
 (0)