@@ -2203,6 +2203,92 @@ static bool useActiveLaneMaskForControlFlow(TailFoldingStyle Style) {
22032203 Style == TailFoldingStyle::DataAndControlFlowWithoutRuntimeCheck;
22042204}
22052205
2206+ static void getPartialReductionInstrChain (Instruction *Instr, SmallVector<Value*, 4 > &Chain) {
2207+ Instruction *Mul = cast<Instruction>(Instr->getOperand (0 ));
2208+ Instruction *Ext0 = cast<ZExtInst>(Mul->getOperand (0 ));
2209+ Instruction *Ext1 = cast<ZExtInst>(Mul->getOperand (1 ));
2210+
2211+ Chain.push_back (Mul);
2212+ Chain.push_back (Ext0);
2213+ Chain.push_back (Ext1);
2214+ Chain.push_back (Instr->getOperand (1 ));
2215+ }
2216+
2217+
2218+ // / @param Instr The root instruction to scan
2219+ static bool isInstrPartialReduction (Instruction *Instr) {
2220+ Value *ExpectedPhi;
2221+ Value *A, *B;
2222+ Value *InductionA, *InductionB;
2223+
2224+ using namespace llvm ::PatternMatch;
2225+ auto Pattern = m_Add (
2226+ m_OneUse (m_Mul (
2227+ m_OneUse (m_ZExt (
2228+ m_OneUse (m_Load (
2229+ m_GEP (
2230+ m_Value (A),
2231+ m_Value (InductionA)))))),
2232+ m_OneUse (m_ZExt (
2233+ m_OneUse (m_Load (
2234+ m_GEP (
2235+ m_Value (B),
2236+ m_Value (InductionB))))))
2237+ )), m_Value (ExpectedPhi));
2238+
2239+ bool Matches = match (Instr, Pattern);
2240+
2241+ if (!Matches)
2242+ return false ;
2243+
2244+ // Check that the two induction variable uses are to the same induction variable
2245+ if (InductionA != InductionB) {
2246+ LLVM_DEBUG (dbgs () << " Loop uses different induction variables for each input variable, cannot create a partial reduction.\n " );
2247+ return false ;
2248+ }
2249+
2250+ Instruction *Mul = cast<Instruction>(Instr->getOperand (0 ));
2251+ Instruction *Ext0 = cast<ZExtInst>(Mul->getOperand (0 ));
2252+ Instruction *Ext1 = cast<ZExtInst>(Mul->getOperand (1 ));
2253+
2254+ // Check that the extends extend to i32
2255+ if (!Ext0->getType ()->isIntegerTy (32 ) || !Ext1->getType ()->isIntegerTy (32 )) {
2256+ LLVM_DEBUG (dbgs () << " Extends don't extend to the correct width, cannot create a partial reduction.\n " );
2257+ return false ;
2258+ }
2259+
2260+ // Check that the loads are loading i8
2261+ LoadInst *Load0 = cast<LoadInst>(Ext0->getOperand (0 ));
2262+ LoadInst *Load1 = cast<LoadInst>(Ext1->getOperand (0 ));
2263+ if (!Load0->getType ()->isIntegerTy (8 ) || !Load1->getType ()->isIntegerTy (8 )) {
2264+ LLVM_DEBUG (dbgs () << " Loads don't load the correct width, cannot create a partial reduction\n " );
2265+ return false ;
2266+ }
2267+
2268+ // Check that the add feeds into ExpectedPhi
2269+ PHINode *PhiNode = dyn_cast<PHINode>(ExpectedPhi);
2270+ if (!PhiNode) {
2271+ LLVM_DEBUG (dbgs () << " Expected Phi node was not a phi, cannot create a partial reduction.\n " );
2272+ return false ;
2273+ }
2274+
2275+ // Check that the first phi value is a zero initializer
2276+ ConstantInt *ZeroInit = dyn_cast<ConstantInt>(PhiNode->getIncomingValue (0 ));
2277+ if (!ZeroInit || !ZeroInit->isZero ()) {
2278+ LLVM_DEBUG (dbgs () << " First PHI value is not a constant zero, cannot create a partial reduction.\n " );
2279+ return false ;
2280+ }
2281+
2282+ // Check that the second phi value is the instruction we're looking at
2283+ Instruction *MaybeAdd = dyn_cast<Instruction>(PhiNode->getIncomingValue (1 ));
2284+ if (!MaybeAdd || MaybeAdd != Instr) {
2285+ LLVM_DEBUG (dbgs () << " Second PHI value is not the root add, cannot create a partial reduction.\n " );
2286+ return false ;
2287+ }
2288+
2289+ return true ;
2290+ }
2291+
22062292// Return true if \p OuterLp is an outer loop annotated with hints for explicit
22072293// vectorization. The loop needs to be annotated with #pragma omp simd
22082294// simdlen(#) or #pragma clang vectorize(enable) vectorize_width(#). If the
@@ -5084,6 +5170,13 @@ bool LoopVectorizationPlanner::isCandidateForEpilogueVectorization(
50845170 return false ;
50855171 }
50865172
5173+ // Prevent epilogue vectorization if a partial reduction is involved
5174+ // TODO Is there a cleaner way to check this?
5175+ if (any_of (Legal->getReductionVars (), [&](const std::pair<PHINode *, RecurrenceDescriptor> &Reduction) {
5176+ return isInstrPartialReduction (Reduction.second .getLoopExitInstr ());
5177+ }))
5178+ return false ;
5179+
50875180 // Epilogue vectorization code has not been auditted to ensure it handles
50885181 // non-latch exits properly. It may be fine, but it needs auditted and
50895182 // tested.
@@ -7182,6 +7275,17 @@ void LoopVectorizationCostModel::collectValuesToIgnore() {
71827275 const SmallVectorImpl<Instruction *> &Casts = IndDes.getCastInsts ();
71837276 VecValuesToIgnore.insert (Casts.begin (), Casts.end ());
71847277 }
7278+
7279+ // Ignore any values that we know will be flattened
7280+ for (auto Reduction : this ->Legal ->getReductionVars ()) {
7281+ auto &Recurrence = Reduction.second ;
7282+ if (isInstrPartialReduction (Recurrence.getLoopExitInstr ())) {
7283+ SmallVector<Value*, 4 > PartialReductionValues;
7284+ getPartialReductionInstrChain (Recurrence.getLoopExitInstr (), PartialReductionValues);
7285+ ValuesToIgnore.insert (PartialReductionValues.begin (), PartialReductionValues.end ());
7286+ VecValuesToIgnore.insert (PartialReductionValues.begin (), PartialReductionValues.end ());
7287+ }
7288+ }
71857289}
71867290
71877291void LoopVectorizationCostModel::collectInLoopReductions () {
@@ -8536,9 +8640,24 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
85368640 *CI);
85378641 }
85388642
8643+ if (auto *PartialReduce = tryToCreatePartialReduction (Range, Instr, Operands))
8644+ return PartialReduce;
8645+
85398646 return tryToWiden (Instr, Operands, VPBB);
85408647}
85418648
8649+ VPRecipeBase *VPRecipeBuilder::tryToCreatePartialReduction (
8650+ VFRange &Range, Instruction *Instr, ArrayRef<VPValue *> Operands) {
8651+
8652+ if (isInstrPartialReduction (Instr)) {
8653+ auto EC = ElementCount::getScalable (16 );
8654+ if (std::find (Range.begin (), Range.end (), EC) == Range.end ())
8655+ return nullptr ;
8656+ return new VPPartialReductionRecipe (*Instr, make_range (Operands.begin (), Operands.end ()));
8657+ }
8658+ return nullptr ;
8659+ }
8660+
85428661void LoopVectorizationPlanner::buildVPlansWithVPRecipes (ElementCount MinVF,
85438662 ElementCount MaxVF) {
85448663 assert (OrigLoop->isInnermost () && " Inner loop expected." );
@@ -8746,6 +8865,9 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
87468865 VPBB->appendRecipe (Recipe);
87478866 }
87488867
8868+ for (auto &Recipe : *VPBB)
8869+ Recipe.postInsertionOp ();
8870+
87498871 VPBlockUtils::insertBlockAfter (new VPBasicBlock (), VPBB);
87508872 VPBB = cast<VPBasicBlock>(VPBB->getSingleSuccessor ());
87518873 }
0 commit comments