@@ -59,19 +59,34 @@ static cl::opt<bool> DisableAll("disable-loop-idiom-vectorize-all", cl::Hidden,
5959 cl::init (false ),
6060 cl::desc(" Disable Loop Idiom Vectorize Pass." ));
6161
62+ static cl::opt<LoopIdiomVectorizeStyle>
63+ LITVecStyle (" loop-idiom-vectorize-style" , cl::Hidden,
64+ cl::desc (" The vectorization style for loop idiom transform." ),
65+ cl::values(clEnumValN(LoopIdiomVectorizeStyle::Masked, " masked" ,
66+ " Use masked vector intrinsics" ),
67+ clEnumValN(LoopIdiomVectorizeStyle::Predicated,
68+ " predicated" , " Use VP intrinsics" )),
69+ cl::init(LoopIdiomVectorizeStyle::Masked));
70+
6271static cl::opt<bool >
6372 DisableByteCmp (" disable-loop-idiom-vectorize-bytecmp" , cl::Hidden,
6473 cl::init (false ),
6574 cl::desc(" Proceed with Loop Idiom Vectorize Pass, but do "
6675 " not convert byte-compare loop(s)." ));
6776
77+ static cl::opt<unsigned >
78+ ByteCmpVF (" loop-idiom-vectorize-bytecmp-vf" , cl::Hidden,
79+ cl::desc (" The vectorization factor for byte-compare patterns." ),
80+ cl::init(16 ));
81+
6882static cl::opt<bool >
6983 VerifyLoops (" loop-idiom-vectorize-verify" , cl::Hidden, cl::init(false ),
7084 cl::desc(" Verify loops generated Loop Idiom Vectorize Pass." ));
7185
7286namespace {
73-
7487class LoopIdiomVectorize {
88+ LoopIdiomVectorizeStyle VectorizeStyle;
89+ unsigned ByteCompareVF;
7590 Loop *CurLoop = nullptr ;
7691 DominatorTree *DT;
7792 LoopInfo *LI;
@@ -86,10 +101,11 @@ class LoopIdiomVectorize {
86101 BasicBlock *VectorLoopIncBlock = nullptr ;
87102
88103public:
89- explicit LoopIdiomVectorize (DominatorTree *DT, LoopInfo *LI,
90- const TargetTransformInfo *TTI,
91- const DataLayout *DL)
92- : DT(DT), LI(LI), TTI(TTI), DL(DL) {}
104+ LoopIdiomVectorize (LoopIdiomVectorizeStyle S, unsigned VF, DominatorTree *DT,
105+ LoopInfo *LI, const TargetTransformInfo *TTI,
106+ const DataLayout *DL)
107+ : VectorizeStyle(S), ByteCompareVF(VF), DT(DT), LI(LI), TTI(TTI), DL(DL) {
108+ }
93109
94110 bool run (Loop *L);
95111
@@ -111,6 +127,10 @@ class LoopIdiomVectorize {
111127 GetElementPtrInst *GEPA,
112128 GetElementPtrInst *GEPB, Value *ExtStart,
113129 Value *ExtEnd);
130+ Value *createPredicatedFindMismatch (IRBuilder<> &Builder, DomTreeUpdater &DTU,
131+ GetElementPtrInst *GEPA,
132+ GetElementPtrInst *GEPB, Value *ExtStart,
133+ Value *ExtEnd);
114134
115135 void transformByteCompare (GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
116136 PHINode *IndPhi, Value *MaxLen, Instruction *Index,
@@ -128,8 +148,16 @@ PreservedAnalyses LoopIdiomVectorizePass::run(Loop &L, LoopAnalysisManager &AM,
128148
129149 const auto *DL = &L.getHeader ()->getDataLayout ();
130150
131- LoopIdiomVectorize LIT (&AR.DT , &AR.LI , &AR.TTI , DL);
132- if (!LIT.run (&L))
151+ LoopIdiomVectorizeStyle VecStyle = VectorizeStyle;
152+ if (LITVecStyle.getNumOccurrences ())
153+ VecStyle = LITVecStyle;
154+
155+ unsigned BCVF = ByteCompareVF;
156+ if (ByteCmpVF.getNumOccurrences ())
157+ BCVF = ByteCmpVF;
158+
159+ LoopIdiomVectorize LIV (VecStyle, BCVF, &AR.DT , &AR.LI , &AR.TTI , DL);
160+ if (!LIV.run (&L))
133161 return PreservedAnalyses::all ();
134162
135163 return PreservedAnalyses::none ();
@@ -354,20 +382,16 @@ Value *LoopIdiomVectorize::createMaskedFindMismatch(
354382 Value *PtrA = GEPA->getPointerOperand ();
355383 Value *PtrB = GEPB->getPointerOperand ();
356384
357- // At this point we know two things must be true:
358- // 1. Start <= End
359- // 2. ExtMaxLen <= MinPageSize due to the page checks.
360- // Therefore, we know that we can use a 64-bit induction variable that
361- // starts from 0 -> ExtMaxLen and it will not overflow.
362385 ScalableVectorType *PredVTy =
363- ScalableVectorType::get (Builder.getInt1Ty (), 16 );
386+ ScalableVectorType::get (Builder.getInt1Ty (), ByteCompareVF );
364387
365388 Value *InitialPred = Builder.CreateIntrinsic (
366389 Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd});
367390
368391 Value *VecLen = Builder.CreateIntrinsic (Intrinsic::vscale, {I64Type}, {});
369- VecLen = Builder.CreateMul (VecLen, ConstantInt::get (I64Type, 16 ), " " ,
370- /* HasNUW=*/ true , /* HasNSW=*/ true );
392+ VecLen =
393+ Builder.CreateMul (VecLen, ConstantInt::get (I64Type, ByteCompareVF), " " ,
394+ /* HasNUW=*/ true , /* HasNSW=*/ true );
371395
372396 Value *PFalse = Builder.CreateVectorSplat (PredVTy->getElementCount (),
373397 Builder.getInt1 (false ));
@@ -385,7 +409,8 @@ Value *LoopIdiomVectorize::createMaskedFindMismatch(
385409 LoopPred->addIncoming (InitialPred, VectorLoopPreheaderBlock);
386410 PHINode *VectorIndexPhi = Builder.CreatePHI (I64Type, 2 , " mismatch_vec_index" );
387411 VectorIndexPhi->addIncoming (ExtStart, VectorLoopPreheaderBlock);
388- Type *VectorLoadType = ScalableVectorType::get (Builder.getInt8Ty (), 16 );
412+ Type *VectorLoadType =
413+ ScalableVectorType::get (Builder.getInt8Ty (), ByteCompareVF);
389414 Value *Passthru = ConstantInt::getNullValue (VectorLoadType);
390415
391416 Value *VectorLhsGep =
@@ -454,6 +479,109 @@ Value *LoopIdiomVectorize::createMaskedFindMismatch(
454479 return Builder.CreateTrunc (VectorLoopRes64, ResType);
455480}
456481
482+ Value *LoopIdiomVectorize::createPredicatedFindMismatch (
483+ IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
484+ GetElementPtrInst *GEPB, Value *ExtStart, Value *ExtEnd) {
485+ Type *I64Type = Builder.getInt64Ty ();
486+ Type *I32Type = Builder.getInt32Ty ();
487+ Type *ResType = I32Type;
488+ Type *LoadType = Builder.getInt8Ty ();
489+ Value *PtrA = GEPA->getPointerOperand ();
490+ Value *PtrB = GEPB->getPointerOperand ();
491+
492+ auto *JumpToVectorLoop = BranchInst::Create (VectorLoopStartBlock);
493+ Builder.Insert (JumpToVectorLoop);
494+
495+ DTU.applyUpdates ({{DominatorTree::Insert, VectorLoopPreheaderBlock,
496+ VectorLoopStartBlock}});
497+
498+ // Set up the first Vector loop block by creating the PHIs, doing the vector
499+ // loads and comparing the vectors.
500+ Builder.SetInsertPoint (VectorLoopStartBlock);
501+ auto *VectorIndexPhi = Builder.CreatePHI (I64Type, 2 , " mismatch_vector_index" );
502+ VectorIndexPhi->addIncoming (ExtStart, VectorLoopPreheaderBlock);
503+
504+ // Calculate AVL by subtracting the vector loop index from the trip count
505+ Value *AVL = Builder.CreateSub (ExtEnd, VectorIndexPhi, " avl" , /* HasNUW=*/ true ,
506+ /* HasNSW=*/ true );
507+
508+ auto *VectorLoadType = ScalableVectorType::get (LoadType, ByteCompareVF);
509+ auto *VF = ConstantInt::get (I32Type, ByteCompareVF);
510+
511+ Value *VL = Builder.CreateIntrinsic (Intrinsic::experimental_get_vector_length,
512+ {I64Type}, {AVL, VF, Builder.getTrue ()});
513+ Value *GepOffset = VectorIndexPhi;
514+
515+ Value *VectorLhsGep =
516+ Builder.CreateGEP (LoadType, PtrA, GepOffset, " " , GEPA->isInBounds ());
517+ VectorType *TrueMaskTy =
518+ VectorType::get (Builder.getInt1Ty (), VectorLoadType->getElementCount ());
519+ Value *AllTrueMask = Constant::getAllOnesValue (TrueMaskTy);
520+ Value *VectorLhsLoad = Builder.CreateIntrinsic (
521+ Intrinsic::vp_load, {VectorLoadType, VectorLhsGep->getType ()},
522+ {VectorLhsGep, AllTrueMask, VL}, nullptr , " lhs.load" );
523+
524+ Value *VectorRhsGep =
525+ Builder.CreateGEP (LoadType, PtrB, GepOffset, " " , GEPB->isInBounds ());
526+ Value *VectorRhsLoad = Builder.CreateIntrinsic (
527+ Intrinsic::vp_load, {VectorLoadType, VectorLhsGep->getType ()},
528+ {VectorRhsGep, AllTrueMask, VL}, nullptr , " rhs.load" );
529+
530+ StringRef PredicateStr = CmpInst::getPredicateName (CmpInst::ICMP_NE);
531+ auto *PredicateMDS = MDString::get (VectorLhsLoad->getContext (), PredicateStr);
532+ Value *Pred = MetadataAsValue::get (VectorLhsLoad->getContext (), PredicateMDS);
533+ Value *VectorMatchCmp = Builder.CreateIntrinsic (
534+ Intrinsic::vp_icmp, {VectorLhsLoad->getType ()},
535+ {VectorLhsLoad, VectorRhsLoad, Pred, AllTrueMask, VL}, nullptr ,
536+ " mismatch.cmp" );
537+ Value *CTZ = Builder.CreateIntrinsic (
538+ Intrinsic::vp_cttz_elts, {ResType, VectorMatchCmp->getType ()},
539+ {VectorMatchCmp, /* ZeroIsPoison=*/ Builder.getInt1 (false ), AllTrueMask,
540+ VL});
541+ Value *MismatchFound = Builder.CreateICmpNE (CTZ, VL);
542+ auto *VectorEarlyExit = BranchInst::Create (VectorLoopMismatchBlock,
543+ VectorLoopIncBlock, MismatchFound);
544+ Builder.Insert (VectorEarlyExit);
545+
546+ DTU.applyUpdates (
547+ {{DominatorTree::Insert, VectorLoopStartBlock, VectorLoopMismatchBlock},
548+ {DominatorTree::Insert, VectorLoopStartBlock, VectorLoopIncBlock}});
549+
550+ // Increment the index counter and calculate the predicate for the next
551+ // iteration of the loop. We branch back to the start of the loop if there
552+ // is at least one active lane.
553+ Builder.SetInsertPoint (VectorLoopIncBlock);
554+ Value *VL64 = Builder.CreateZExt (VL, I64Type);
555+ Value *NewVectorIndexPhi =
556+ Builder.CreateAdd (VectorIndexPhi, VL64, " " ,
557+ /* HasNUW=*/ true , /* HasNSW=*/ true );
558+ VectorIndexPhi->addIncoming (NewVectorIndexPhi, VectorLoopIncBlock);
559+ Value *ExitCond = Builder.CreateICmpNE (NewVectorIndexPhi, ExtEnd);
560+ auto *VectorLoopBranchBack =
561+ BranchInst::Create (VectorLoopStartBlock, EndBlock, ExitCond);
562+ Builder.Insert (VectorLoopBranchBack);
563+
564+ DTU.applyUpdates (
565+ {{DominatorTree::Insert, VectorLoopIncBlock, VectorLoopStartBlock},
566+ {DominatorTree::Insert, VectorLoopIncBlock, EndBlock}});
567+
568+ // If we found a mismatch then we need to calculate which lane in the vector
569+ // had a mismatch and add that on to the current loop index.
570+ Builder.SetInsertPoint (VectorLoopMismatchBlock);
571+
572+ // Add LCSSA phis for CTZ and VectorIndexPhi.
573+ auto *CTZLCSSAPhi = Builder.CreatePHI (CTZ->getType (), 1 , " ctz" );
574+ CTZLCSSAPhi->addIncoming (CTZ, VectorLoopStartBlock);
575+ auto *VectorIndexLCSSAPhi =
576+ Builder.CreatePHI (VectorIndexPhi->getType (), 1 , " mismatch_vector_index" );
577+ VectorIndexLCSSAPhi->addIncoming (VectorIndexPhi, VectorLoopStartBlock);
578+
579+ Value *CTZI64 = Builder.CreateZExt (CTZLCSSAPhi, I64Type);
580+ Value *VectorLoopRes64 = Builder.CreateAdd (VectorIndexLCSSAPhi, CTZI64, " " ,
581+ /* HasNUW=*/ true , /* HasNSW=*/ true );
582+ return Builder.CreateTrunc (VectorLoopRes64, ResType);
583+ }
584+
457585Value *LoopIdiomVectorize::expandFindMismatch (
458586 IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
459587 GetElementPtrInst *GEPB, Instruction *Index, Value *Start, Value *MaxLen) {
@@ -613,8 +741,22 @@ Value *LoopIdiomVectorize::expandFindMismatch(
613741 // processed in each iteration, etc.
614742 Builder.SetInsertPoint (VectorLoopPreheaderBlock);
615743
616- Value *VectorLoopRes =
617- createMaskedFindMismatch (Builder, DTU, GEPA, GEPB, ExtStart, ExtEnd);
744+ // At this point we know two things must be true:
745+ // 1. Start <= End
746+ // 2. ExtMaxLen <= MinPageSize due to the page checks.
747+ // Therefore, we know that we can use a 64-bit induction variable that
748+ // starts from 0 -> ExtMaxLen and it will not overflow.
749+ Value *VectorLoopRes = nullptr ;
750+ switch (VectorizeStyle) {
751+ case LoopIdiomVectorizeStyle::Masked:
752+ VectorLoopRes =
753+ createMaskedFindMismatch (Builder, DTU, GEPA, GEPB, ExtStart, ExtEnd);
754+ break ;
755+ case LoopIdiomVectorizeStyle::Predicated:
756+ VectorLoopRes = createPredicatedFindMismatch (Builder, DTU, GEPA, GEPB,
757+ ExtStart, ExtEnd);
758+ break ;
759+ }
618760
619761 Builder.Insert (BranchInst::Create (EndBlock));
620762
0 commit comments