Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
223 changes: 222 additions & 1 deletion llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,10 @@ class LoopIdiomRecognize {
bool recognizeShiftUntilBitTest();
bool recognizeShiftUntilZero();

bool recognizeAndInsertCtz();
void transformLoopToCtz(BasicBlock *PreCondBB, Instruction *CntInst,
PHINode *CntPhi, Value *Var);

/// @}
};
} // end anonymous namespace
Expand Down Expand Up @@ -1484,7 +1488,8 @@ bool LoopIdiomRecognize::runOnNoncountableLoop() {
<< CurLoop->getHeader()->getName() << "\n");

return recognizePopcount() || recognizeAndInsertFFS() ||
recognizeShiftUntilBitTest() || recognizeShiftUntilZero();
recognizeShiftUntilBitTest() || recognizeShiftUntilZero() ||
recognizeAndInsertCtz();
}

/// Check if the given conditional branch is based on the comparison between
Expand Down Expand Up @@ -2868,3 +2873,219 @@ bool LoopIdiomRecognize::recognizeShiftUntilZero() {
++NumShiftUntilZero;
return MadeChange;
}

/// Detect a single-block loop that counts trailing zeros by shifting a value
/// right one bit per iteration until bit 1 of the pre-shift value is set:
///
/// \code
/// loop:
///   %count.010 = phi i32 [ %add, %while.body ], [ 0, %while.body.preheader ]
///   %n.addr.09 = phi i32 [ %shr, %while.body ], [ %n, %while.body.preheader ]
///   %add = add nuw nsw i32 %count.010, 1
///   %shr = ashr exact i32 %n.addr.09, 1
///   %0 = and i32 %n.addr.09, 2
///   %cmp1 = icmp eq i32 %0, 0
///   br i1 %cmp1, label %while.body, label %if.end.loopexit
/// \endcode
///
/// \param CurLoop The loop to inspect; only its first block is examined.
/// \param InitX   [out] Value flowing into the shifted PHI from the preheader.
/// \param CntInst [out] The "cnt + 1" add found in the loop block.
/// \param CntPhi  [out] The PHI that getRecurrenceVar() associates with
///                CntInst (presumably the counter recurrence in the loop
///                block -- confirm against getRecurrenceVar).
/// \returns true if every part of the pattern was found. On failure CntInst
///          and CntPhi are null and InitX must not be relied upon.
static bool detectShiftUntilZeroAndOneIdiom(Loop *CurLoop, Value *&InitX,
                                            Instruction *&CntInst,
                                            PHINode *&CntPhi) {
  CntInst = nullptr;
  CntPhi = nullptr;

  BasicBlock *LoopEntry = *(CurLoop->block_begin());

  // The initial value of the shifted variable is read through the preheader
  // edge below, so bail out early if there is no preheader.
  BasicBlock *Preheader = CurLoop->getLoopPreheader();
  if (!Preheader)
    return false;

  // Check if the loop-back branch is in desirable form.
  //   "if (x == 0) goto loop-entry"
  Value *CondVal = matchCondition(
      dyn_cast<BranchInst>(LoopEntry->getTerminator()), LoopEntry, true);
  if (!CondVal) {
    LLVM_DEBUG(dbgs() << "Bad condition for branch instruction\n");
    return false;
  }

  // The tested value must be "x & 2": earlier passes fold "(x >> 1) & 1" into
  // that form. Match on the Value directly; the compared operand is not
  // guaranteed to be an Instruction, so it must not be cast and dereferenced
  // unchecked.
  Value *VarX = nullptr;
  if (!match(CondVal, m_c_And(PatternMatch::m_Value(VarX),
                              PatternMatch::m_SpecificInt(2))))
    return false;

  // "x" has to be a PHI that lives in the loop block.
  auto *PhiX = dyn_cast<PHINode>(VarX);
  if (!PhiX || PhiX->getParent() != LoopEntry)
    return false;

  // Look at the value coming around the back edge specifically. Scanning the
  // PHI operands for "any shift" could pick the preheader value when that
  // value happens to be a shift itself, and then reject a valid idiom.
  const int BackedgeIdx = PhiX->getBasicBlockIndex(LoopEntry);
  if (BackedgeIdx < 0)
    return false;

  // The back-edge value must be "x >> 1" (arithmetic or logical) of the PHI.
  Value *Step = PhiX->getIncomingValue(BackedgeIdx);
  if (!match(Step, PatternMatch::m_Shr(PatternMatch::m_Specific(PhiX),
                                       PatternMatch::m_One())))
    return false;

  InitX = PhiX->getIncomingValueForBlock(Preheader);

  // Find the instruction which counts the trailing zeros: cnt.next = cnt + 1.
  for (Instruction &Inst : llvm::make_range(
           LoopEntry->getFirstNonPHI()->getIterator(), LoopEntry->end())) {
    if (Inst.getOpcode() != Instruction::Add)
      continue;

    auto *Inc = dyn_cast<ConstantInt>(Inst.getOperand(1));
    if (!Inc || !Inc->isOne())
      continue;

    PHINode *Phi = getRecurrenceVar(Inst.getOperand(0), &Inst, LoopEntry);
    if (!Phi)
      continue;

    CntInst = &Inst;
    CntPhi = Phi;
    break;
  }

  return CntInst != nullptr;
}

/// Recognize CTTZ idiom in a non-countable loop and convert it to countable
/// with CTTZ of variable as a trip count. If CTTZ was inserted, returns true;
/// otherwise, returns false.
///
// int count_trailing_zeroes(uint32_t n) {
// int count = 0;
// if (n == 0){
// return 32;
// }
// while ((n & 1) == 0) {
// count += 1;
// n >>= 1;
// }
//
//
// return count;
// }
bool LoopIdiomRecognize::recognizeAndInsertCtz() {
// Give up if the loop has multiple blocks or multiple backedges.
if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 1)
return false;

Value *InitX;
PHINode *CntPhi = nullptr;
Instruction *CntInst = nullptr;
// For counting trailing zeros with uncountable loop idiom, transformation is
// always profitable if IdiomCanonicalSize is 7.
const size_t IdiomCanonicalSize = 7;

if (!detectShiftUntilZeroAndOneIdiom(CurLoop, InitX, CntInst, CntPhi))
return false;

BasicBlock *PH = CurLoop->getLoopPreheader();

auto *PreCondBB = PH->getSinglePredecessor();
if (!PreCondBB)
return false;
auto *PreCondBI = dyn_cast<BranchInst>(PreCondBB->getTerminator());
if (!PreCondBI)
return false;

// check that initial value is not zero and "(init & 1) == 0"
// initial value must not be zero, because it will cause infinite loop
// without this check, after replacing the loop with cttz, the counter will be
// size of int, while before the replacement the loop would have executed
// indefinitely

// match that case, where n is initial value
// entry:
// %cmp.not = icmp eq i32 %n, 0
// br i1 %cmp.not, label %cleanup, label %while.cond.preheader
//
// while.cond.preheader:
// %and5 = and i32 %n, 1
// %cmp16 = icmp eq i32 %and5, 0
// br i1 %cmp16, label %while.body.preheader, label %cleanup

Value *PreCond = matchCondition(PreCondBI, PH, true);

if (!PreCond)
return false;

Value *InitPredX = nullptr;
if (!match(PreCond, m_c_And(PatternMatch::m_Value(InitPredX),
PatternMatch::m_One())) ||
InitPredX != InitX)
return false;
auto *PrePreCondBB = PreCondBB->getSinglePredecessor();
if (!PrePreCondBB)
return false;
auto *PrePreCondBI = dyn_cast<BranchInst>(PrePreCondBB->getTerminator());
if (!PrePreCondBI)
return false;
if (matchCondition(PrePreCondBI, PreCondBB) != InitX)
return false;

// CTTZ intrinsic always profitable after deleting the loop.
// the loop has only 7 instructions:

// @llvm.dbg doesn't count as they have no semantic effect.
auto InstWithoutDebugIt = CurLoop->getHeader()->instructionsWithoutDebug();
uint32_t HeaderSize =
std::distance(InstWithoutDebugIt.begin(), InstWithoutDebugIt.end());
if (HeaderSize != IdiomCanonicalSize)
return false;

transformLoopToCtz(PH, CntInst, CntPhi, InitX);
return true;
}

/// Emit a CTTZ computation of \p InitX in \p Preheader and redirect every use
/// of the loop counter \p CntInst that lies outside the loop block to it.
///
/// \param Preheader Block whose terminating branch the new instructions are
///                  inserted in front of.
/// \param CntInst   The in-loop counter increment whose outside uses are
///                  replaced.
/// \param CntPhi    The counter PHI; its preheader incoming value is the
///                  counter's starting value.
/// \param InitX     The value whose trailing zeros are counted.
///
/// The loop itself is not removed here; only uses outside its block change.
void LoopIdiomRecognize::transformLoopToCtz(BasicBlock *Preheader,
                                            Instruction *CntInst,
                                            PHINode *CntPhi, Value *InitX) {
  const DebugLoc &DL = CntInst->getDebugLoc();

  // New instructions go right before the preheader's terminating branch.
  IRBuilder<> Builder(cast<BranchInst>(Preheader->getTerminator()));
  Builder.SetCurrentDebugLocation(DL);

  // cttz(InitX), then adjusted to the counter's integer type.
  Value *RawCttz = createFFSIntrinsic(Builder, InitX, DL,
                                      /* is zero poison */ true,
                                      Intrinsic::cttz);
  Value *TrailingZeros =
      Builder.CreateZExtOrTrunc(RawCttz, CntInst->getType());

  // Fold in the counter's starting value unless it is the constant zero, in
  // which case the addition would be a no-op.
  Value *CntStart = CntPhi->getIncomingValueForBlock(Preheader);
  auto *StartConst = dyn_cast<ConstantInt>(CntStart);
  const bool StartsAtZero = StartConst && StartConst->isZero();
  Value *FinalCount = TrailingZeros;
  if (!StartsAtZero)
    FinalCount = Builder.CreateAdd(TrailingZeros, CntStart);

  // Users of the counter outside the loop block now read the computed count.
  BasicBlock *LoopBlock = *(CurLoop->block_begin());
  CntInst->replaceUsesOutsideBlock(FinalCount, LoopBlock);
  SE->forgetLoop(CurLoop);
}
147 changes: 147 additions & 0 deletions llvm/test/Transforms/LoopIdiom/RISCV/cttz.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt -passes=loop-idiom -mtriple=riscv32 -S < %s | FileCheck %s
; RUN: opt -passes=loop-idiom -mtriple=riscv64 -S < %s | FileCheck %s

; Copied from popcnt test.

;To recognize this pattern:
;int ctz(uint32_t n)
;{
; int count = 0;
; if (n == 0)
; {
; return 32;
; }
; while ((n & 1) == 0)
; {
; count += 1;
; n >>= 1;
; }
; return count;
;}

; 32-bit case. In the input IR the loop block %while.body is entered directly
; from the "(n & 1) == 0" guard block, with no dedicated preheader. The CHECK
; lines expect a %while.body.preheader block holding the llvm.cttz.i32 call
; and a %cleanup.loopexit phi that takes the cttz result instead of %add.
; The loop body itself is expected to remain in place.
define signext i32 @count_trailing_zeroes(i32 noundef signext %n) local_unnamed_addr #0 {
; CHECK-LABEL: @count_trailing_zeroes(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[N:%.*]], 0
; CHECK-NEXT: br i1 [[CMP]], label [[CLEANUP:%.*]], label [[WHILE_COND_PREHEADER:%.*]]
; CHECK: while.cond.preheader:
; CHECK-NEXT: [[AND4:%.*]] = and i32 [[N]], 1
; CHECK-NEXT: [[CMP15:%.*]] = icmp eq i32 [[AND4]], 0
; CHECK-NEXT: br i1 [[CMP15]], label [[WHILE_BODY_PREHEADER:%.*]], label [[CLEANUP]]
; CHECK: while.body.preheader:
; CHECK-NEXT: [[TMP0:%.*]] = call i32 @llvm.cttz.i32(i32 [[N]], i1 true)
; CHECK-NEXT: br label [[WHILE_BODY:%.*]]
; CHECK: while.body:
; CHECK-NEXT: [[COUNT_07:%.*]] = phi i32 [ [[ADD:%.*]], [[WHILE_BODY]] ], [ 0, [[WHILE_BODY_PREHEADER]] ]
; CHECK-NEXT: [[N_ADDR_06:%.*]] = phi i32 [ [[SHR:%.*]], [[WHILE_BODY]] ], [ [[N]], [[WHILE_BODY_PREHEADER]] ]
; CHECK-NEXT: [[ADD]] = add nuw nsw i32 [[COUNT_07]], 1
; CHECK-NEXT: [[SHR]] = lshr i32 [[N_ADDR_06]], 1
; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[N_ADDR_06]], 2
; CHECK-NEXT: [[CMP1:%.*]] = icmp eq i32 [[TMP1]], 0
; CHECK-NEXT: br i1 [[CMP1]], label [[WHILE_BODY]], label [[CLEANUP_LOOPEXIT:%.*]]
; CHECK: cleanup.loopexit:
; CHECK-NEXT: [[ADD_LCSSA:%.*]] = phi i32 [ [[TMP0]], [[WHILE_BODY]] ]
; CHECK-NEXT: br label [[CLEANUP]]
; CHECK: cleanup:
; CHECK-NEXT: [[RETVAL_0:%.*]] = phi i32 [ 32, [[ENTRY:%.*]] ], [ 0, [[WHILE_COND_PREHEADER]] ], [ [[ADD_LCSSA]], [[CLEANUP_LOOPEXIT]] ]
; CHECK-NEXT: ret i32 [[RETVAL_0]]
;
entry:
%cmp = icmp eq i32 %n, 0
br i1 %cmp, label %cleanup, label %while.cond.preheader

while.cond.preheader: ; preds = %entry
%and4 = and i32 %n, 1
%cmp15 = icmp eq i32 %and4, 0
br i1 %cmp15, label %while.body, label %cleanup

while.body: ; preds = %while.cond.preheader, %while.body
%count.07 = phi i32 [ %add, %while.body ], [ 0, %while.cond.preheader ]
%n.addr.06 = phi i32 [ %shr, %while.body ], [ %n, %while.cond.preheader ]
%add = add nuw nsw i32 %count.07, 1
%shr = lshr i32 %n.addr.06, 1
%0 = and i32 %n.addr.06, 2
%cmp1 = icmp eq i32 %0, 0
br i1 %cmp1, label %while.body, label %cleanup

cleanup: ; preds = %while.body, %while.cond.preheader, %entry
%retval.0 = phi i32 [ 32, %entry ], [ 0, %while.cond.preheader ], [ %add, %while.body ]
ret i32 %retval.0
}

;int ctz(uint64_t n)
;{
; int count = 0;
; if (n != 0)
; {
; while ((n & 1) == 0)
; {
; n >>= 1;
; count += 1;
; }
; }
; else
; {
; return 64;
; }
; return count;
;}

; 64-bit input with a 32-bit counter. The input already has a dedicated
; preheader and an LCSSA phi in %cleanup.loopexit. The CHECK lines expect
; llvm.cttz.i64 in the preheader, a trunc of its result to i32, and the
; LCSSA phi rewritten to use the truncated value instead of %add.
define dso_local signext i32 @ctz(i64 noundef %n) local_unnamed_addr {
; CHECK-LABEL: @ctz(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[CMP_NOT:%.*]] = icmp eq i64 [[N:%.*]], 0
; CHECK-NEXT: br i1 [[CMP_NOT]], label [[CLEANUP:%.*]], label [[WHILE_COND_PREHEADER:%.*]]
; CHECK: while.cond.preheader:
; CHECK-NEXT: [[AND5:%.*]] = and i64 [[N]], 1
; CHECK-NEXT: [[CMP16:%.*]] = icmp eq i64 [[AND5]], 0
; CHECK-NEXT: br i1 [[CMP16]], label [[WHILE_BODY_PREHEADER:%.*]], label [[CLEANUP]]
; CHECK: while.body.preheader:
; CHECK-NEXT: [[TMP0:%.*]] = call i64 @llvm.cttz.i64(i64 [[N]], i1 true)
; CHECK-NEXT: [[TMP1:%.*]] = trunc i64 [[TMP0]] to i32
; CHECK-NEXT: br label [[WHILE_BODY:%.*]]
; CHECK: while.body:
; CHECK-NEXT: [[COUNT_08:%.*]] = phi i32 [ [[ADD:%.*]], [[WHILE_BODY]] ], [ 0, [[WHILE_BODY_PREHEADER]] ]
; CHECK-NEXT: [[N_ADDR_07:%.*]] = phi i64 [ [[SHR:%.*]], [[WHILE_BODY]] ], [ [[N]], [[WHILE_BODY_PREHEADER]] ]
; CHECK-NEXT: [[SHR]] = lshr i64 [[N_ADDR_07]], 1
; CHECK-NEXT: [[ADD]] = add nuw nsw i32 [[COUNT_08]], 1
; CHECK-NEXT: [[TMP2:%.*]] = and i64 [[N_ADDR_07]], 2
; CHECK-NEXT: [[CMP1:%.*]] = icmp eq i64 [[TMP2]], 0
; CHECK-NEXT: br i1 [[CMP1]], label [[WHILE_BODY]], label [[CLEANUP_LOOPEXIT:%.*]]
; CHECK: cleanup.loopexit:
; CHECK-NEXT: [[ADD_LCSSA:%.*]] = phi i32 [ [[TMP1]], [[WHILE_BODY]] ]
; CHECK-NEXT: br label [[CLEANUP]]
; CHECK: cleanup:
; CHECK-NEXT: [[RETVAL_0:%.*]] = phi i32 [ 64, [[ENTRY:%.*]] ], [ 0, [[WHILE_COND_PREHEADER]] ], [ [[ADD_LCSSA]], [[CLEANUP_LOOPEXIT]] ]
; CHECK-NEXT: ret i32 [[RETVAL_0]]
;
entry:
%cmp.not = icmp eq i64 %n, 0
br i1 %cmp.not, label %cleanup, label %while.cond.preheader

while.cond.preheader: ; preds = %entry
%and5 = and i64 %n, 1
%cmp16 = icmp eq i64 %and5, 0
br i1 %cmp16, label %while.body.preheader, label %cleanup

while.body.preheader: ; preds = %while.cond.preheader
br label %while.body

while.body: ; preds = %while.body.preheader, %while.body
%count.08 = phi i32 [ %add, %while.body ], [ 0, %while.body.preheader ]
%n.addr.07 = phi i64 [ %shr, %while.body ], [ %n, %while.body.preheader ]
%shr = lshr i64 %n.addr.07, 1
%add = add nuw nsw i32 %count.08, 1
%0 = and i64 %n.addr.07, 2
%cmp1 = icmp eq i64 %0, 0
br i1 %cmp1, label %while.body, label %cleanup.loopexit

cleanup.loopexit: ; preds = %while.body
%add.lcssa = phi i32 [ %add, %while.body ]
br label %cleanup

cleanup: ; preds = %cleanup.loopexit, %while.cond.preheader, %entry
%retval.0 = phi i32 [ 64, %entry ], [ 0, %while.cond.preheader ], [ %add.lcssa, %cleanup.loopexit ]
ret i32 %retval.0
}