@@ -243,6 +243,10 @@ class LoopIdiomRecognize {
243243 bool recognizeShiftUntilBitTest ();
244244 bool recognizeShiftUntilZero ();
245245
246+ bool recognizeAndInsertCtz ();
247+ void transformLoopToCtz (BasicBlock *PreCondBB, Instruction *CntInst,
248+ PHINode *CntPhi, Value *Var);
249+
246250 // / @}
247251};
248252} // end anonymous namespace
@@ -1484,7 +1488,8 @@ bool LoopIdiomRecognize::runOnNoncountableLoop() {
14841488 << CurLoop->getHeader ()->getName () << " \n " );
14851489
14861490 return recognizePopcount () || recognizeAndInsertFFS () ||
1487- recognizeShiftUntilBitTest () || recognizeShiftUntilZero ();
1491+ recognizeShiftUntilBitTest () || recognizeShiftUntilZero () ||
1492+ recognizeAndInsertCtz ();
14881493}
14891494
14901495// / Check if the given conditional branch is based on the comparison between
@@ -2868,3 +2873,219 @@ bool LoopIdiomRecognize::recognizeShiftUntilZero() {
28682873 ++NumShiftUntilZero;
28692874 return MadeChange;
28702875}
2876+
2877+ // This function recognizes a loop that counts the number of trailing zeros
2878+ // loop:
2879+ // %count.010 = phi i32 [ %add, %while.body ], [ 0, %while.body.preheader ]
2880+ // %n.addr.09 = phi i32 [ %shr, %while.body ], [ %n, %while.body.preheader ]
2881+ // %add = add nuw nsw i32 %count.010, 1
2882+ // %shr = ashr exact i32 %n.addr.09, 1
2883+ // %0 = and i32 %n.addr.09, 2
2884+ // %cmp1 = icmp eq i32 %0, 0
2885+ // br i1 %cmp1, label %while.body, label %if.end.loopexit
2886+ static bool detectShiftUntilZeroAndOneIdiom (Loop *CurLoop, Value *&InitX,
2887+ Instruction *&CntInst,
2888+ PHINode *&CntPhi) {
2889+ BasicBlock *LoopEntry;
2890+ Value *VarX;
2891+ Instruction *DefX;
2892+
2893+ CntInst = nullptr ;
2894+ CntPhi = nullptr ;
2895+ LoopEntry = *(CurLoop->block_begin ());
2896+
2897+ // Check if the loop-back branch is in desirable form.
2898+ // "if (x == 0) goto loop-entry"
2899+ if (Value *T = matchCondition (
2900+ dyn_cast<BranchInst>(LoopEntry->getTerminator ()), LoopEntry, true )) {
2901+ DefX = dyn_cast<Instruction>(T);
2902+ } else {
2903+ LLVM_DEBUG (dbgs () << " Bad condition for branch instruction\n " );
2904+ return false ;
2905+ }
2906+
2907+ // operand compares with 2, because we are looking for "x & 2"
2908+ // which was optimized by previous passes from "(x >> 1) & 1"
2909+
2910+ if (!match (DefX, m_c_And (PatternMatch::m_Value (VarX),
2911+ PatternMatch::m_SpecificInt (2 ))))
2912+ return false ;
2913+
2914+ // check if VarX is a phi node
2915+
2916+ auto *PhiX = dyn_cast<PHINode>(VarX);
2917+
2918+ if (!PhiX || PhiX->getParent () != LoopEntry)
2919+ return false ;
2920+
2921+ Instruction *DefXRShift = nullptr ;
2922+
2923+ // check if PhiX has a shift instruction as a operand, which is a "x >> 1"
2924+
2925+ for (int i = 0 ; i < 2 ; ++i) {
2926+ if (auto *Inst = dyn_cast<Instruction>(PhiX->getOperand (i))) {
2927+ if (Inst->getOpcode () == Instruction::AShr ||
2928+ Inst->getOpcode () == Instruction::LShr) {
2929+ DefXRShift = Inst;
2930+ break ;
2931+ }
2932+ }
2933+ }
2934+
2935+ if (DefXRShift == nullptr )
2936+ return false ;
2937+
2938+ // check if the shift instruction is a "x >> 1"
2939+ auto *Shft = dyn_cast<ConstantInt>(DefXRShift->getOperand (1 ));
2940+ if (!Shft || !Shft->isOne ())
2941+ return false ;
2942+
2943+ if (DefXRShift->getOperand (0 ) != VarX)
2944+ return false ;
2945+
2946+ InitX = PhiX->getIncomingValueForBlock (CurLoop->getLoopPreheader ());
2947+
2948+ // Find the instruction which counts the trailing zeros: cnt.next = cnt + 1.
2949+ for (Instruction &Inst : llvm::make_range (
2950+ LoopEntry->getFirstNonPHI ()->getIterator (), LoopEntry->end ())) {
2951+ if (Inst.getOpcode () != Instruction::Add)
2952+ continue ;
2953+
2954+ ConstantInt *Inc = dyn_cast<ConstantInt>(Inst.getOperand (1 ));
2955+ if (!Inc || !Inc->isOne ())
2956+ continue ;
2957+
2958+ PHINode *Phi = getRecurrenceVar (Inst.getOperand (0 ), &Inst, LoopEntry);
2959+ if (!Phi)
2960+ continue ;
2961+
2962+ CntInst = &Inst;
2963+ CntPhi = Phi;
2964+ break ;
2965+ }
2966+ if (!CntInst)
2967+ return false ;
2968+
2969+ return true ;
2970+ }
2971+
2972+ // / Recognize CTTZ idiom in a non-countable loop and convert it to countable
2973+ // / with CTTZ of variable as a trip count. If CTTZ was inserted, returns true;
2974+ // / otherwise, returns false.
2975+ // /
2976+ // int count_trailing_zeroes(uint32_t n) {
2977+ // int count = 0;
2978+ // if (n == 0){
2979+ // return 32;
2980+ // }
2981+ // while ((n & 1) == 0) {
2982+ // count += 1;
2983+ // n >>= 1;
2984+ // }
2985+ //
2986+ //
2987+ // return count;
2988+ // }
2989+ bool LoopIdiomRecognize::recognizeAndInsertCtz () {
2990+ // Give up if the loop has multiple blocks or multiple backedges.
2991+ if (CurLoop->getNumBackEdges () != 1 || CurLoop->getNumBlocks () != 1 )
2992+ return false ;
2993+
2994+ Value *InitX;
2995+ PHINode *CntPhi = nullptr ;
2996+ Instruction *CntInst = nullptr ;
2997+ // For counting trailing zeros with uncountable loop idiom, transformation is
2998+ // always profitable if IdiomCanonicalSize is 7.
2999+ const size_t IdiomCanonicalSize = 7 ;
3000+
3001+ if (!detectShiftUntilZeroAndOneIdiom (CurLoop, InitX, CntInst, CntPhi))
3002+ return false ;
3003+
3004+ BasicBlock *PH = CurLoop->getLoopPreheader ();
3005+
3006+ auto *PreCondBB = PH->getSinglePredecessor ();
3007+ if (!PreCondBB)
3008+ return false ;
3009+ auto *PreCondBI = dyn_cast<BranchInst>(PreCondBB->getTerminator ());
3010+ if (!PreCondBI)
3011+ return false ;
3012+
3013+ // check that initial value is not zero and "(init & 1) == 0"
3014+ // initial value must not be zero, because it will cause infinite loop
3015+ // without this check, after replacing the loop with cttz, the counter will be
3016+ // size of int, while before the replacement the loop would have executed
3017+ // indefinitely
3018+
3019+ // match that case, where n is initial value
3020+ // entry:
3021+ // %cmp.not = icmp eq i32 %n, 0
3022+ // br i1 %cmp.not, label %cleanup, label %while.cond.preheader
3023+ //
3024+ // while.cond.preheader:
3025+ // %and5 = and i32 %n, 1
3026+ // %cmp16 = icmp eq i32 %and5, 0
3027+ // br i1 %cmp16, label %while.body.preheader, label %cleanup
3028+
3029+ Value *PreCond = matchCondition (PreCondBI, PH, true );
3030+
3031+ if (!PreCond)
3032+ return false ;
3033+
3034+ Value *InitPredX = nullptr ;
3035+ if (!match (PreCond, m_c_And (PatternMatch::m_Value (InitPredX),
3036+ PatternMatch::m_One ())) ||
3037+ InitPredX != InitX)
3038+ return false ;
3039+ auto *PrePreCondBB = PreCondBB->getSinglePredecessor ();
3040+ if (!PrePreCondBB)
3041+ return false ;
3042+ auto *PrePreCondBI = dyn_cast<BranchInst>(PrePreCondBB->getTerminator ());
3043+ if (!PrePreCondBI)
3044+ return false ;
3045+ if (matchCondition (PrePreCondBI, PreCondBB) != InitX)
3046+ return false ;
3047+
3048+ // CTTZ intrinsic always profitable after deleting the loop.
3049+ // the loop has only 7 instructions:
3050+
3051+ // @llvm.dbg doesn't count as they have no semantic effect.
3052+ auto InstWithoutDebugIt = CurLoop->getHeader ()->instructionsWithoutDebug ();
3053+ uint32_t HeaderSize =
3054+ std::distance (InstWithoutDebugIt.begin (), InstWithoutDebugIt.end ());
3055+ if (HeaderSize != IdiomCanonicalSize)
3056+ return false ;
3057+
3058+ transformLoopToCtz (PH, CntInst, CntPhi, InitX);
3059+ return true ;
3060+ }
3061+
3062+ void LoopIdiomRecognize::transformLoopToCtz (BasicBlock *Preheader,
3063+ Instruction *CntInst,
3064+ PHINode *CntPhi, Value *InitX) {
3065+ BranchInst *PreheaderBr = cast<BranchInst>(Preheader->getTerminator ());
3066+ const DebugLoc &DL = CntInst->getDebugLoc ();
3067+
3068+ // Insert the CTTZ instruction at the end of the preheader block
3069+ IRBuilder<> Builder (PreheaderBr);
3070+ Builder.SetCurrentDebugLocation (DL);
3071+ Value *Count = createFFSIntrinsic (Builder, InitX, DL,
3072+ /* is zero poison */ true , Intrinsic::cttz);
3073+
3074+ Value *NewCount = Count;
3075+
3076+ NewCount = Builder.CreateZExtOrTrunc (NewCount, CntInst->getType ());
3077+
3078+ Value *CntInitVal = CntPhi->getIncomingValueForBlock (Preheader);
3079+ // If the counter was being incremented in the loop, add NewCount to the
3080+ // counter's initial value, but only if the initial value is not zero.
3081+ ConstantInt *InitConst = dyn_cast<ConstantInt>(CntInitVal);
3082+ if (!InitConst || !InitConst->isZero ())
3083+ NewCount = Builder.CreateAdd (NewCount, CntInitVal);
3084+
3085+ BasicBlock *Body = *(CurLoop->block_begin ());
3086+
3087+ // All the references to the original counter outside
3088+ // the loop are replaced with the NewCount
3089+ CntInst->replaceUsesOutsideBlock (NewCount, Body);
3090+ SE->forgetLoop (CurLoop);
3091+ }
0 commit comments