@@ -7273,6 +7273,33 @@ static void fixReductionScalarResumeWhenVectorizingEpilog(
72737273 BypassBlock, MainResumePhi->getIncomingValueForBlock (BypassBlock));
72747274}
72757275
7276+ // / Add branch weight metadata, if the \p Plan's middle block is terminated by a
7277+ // / BranchOnCond recipe.
7278+ static void addBranchWeightToMiddleTerminator (VPlan &Plan, ElementCount VF,
7279+ Loop *OrigLoop) {
7280+ // 4. Adjust branch weight of the branch in the middle block.
7281+ Instruction *LatchTerm = OrigLoop->getLoopLatch ()->getTerminator ();
7282+ if (!hasBranchWeightMD (*LatchTerm))
7283+ return ;
7284+
7285+ VPBasicBlock *MiddleVPBB = Plan.getMiddleBlock ();
7286+ auto *MiddleTerm =
7287+ dyn_cast_or_null<VPInstruction>(MiddleVPBB->getTerminator ());
7288+ // Only add branch metadata if there is a (conditional) terminator.
7289+ if (!MiddleTerm)
7290+ return ;
7291+
7292+ assert (MiddleTerm->getOpcode () == VPInstruction::BranchOnCond &&
7293+ " must have a BranchOnCond" );
7294+ // Assume that `Count % VectorTripCount` is equally distributed.
7295+ unsigned TripCount = Plan.getUF () * VF.getKnownMinValue ();
7296+ assert (TripCount > 0 && " trip count should not be zero" );
7297+ MDBuilder MDB (LatchTerm->getContext ());
7298+ MDNode *BranchWeights =
7299+ MDB.createBranchWeights ({1 , TripCount - 1 }, /* IsExpected=*/ false );
7300+ MiddleTerm->addMetadata (LLVMContext::MD_prof, BranchWeights);
7301+ }
7302+
72767303DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan (
72777304 ElementCount BestVF, unsigned BestUF, VPlan &BestVPlan,
72787305 InnerLoopVectorizer &ILV, DominatorTree *DT, bool VectorizingEpilogue) {
@@ -7295,11 +7322,8 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
72957322
72967323 VPlanTransforms::convertToConcreteRecipes (BestVPlan,
72977324 *Legal->getWidestInductionType ());
7298- // Retrieve and store the middle block before dissolving regions. Regions are
7299- // dissolved after optimizing for VF and UF, which completely removes unneeded
7300- // loop regions first.
7301- VPBasicBlock *MiddleVPBB =
7302- BestVPlan.getVectorLoopRegion () ? BestVPlan.getMiddleBlock () : nullptr ;
7325+
7326+ addBranchWeightToMiddleTerminator (BestVPlan, BestVF, OrigLoop);
73037327 VPlanTransforms::dissolveLoopRegions (BestVPlan);
73047328 // Perform the actual loop transformation.
73057329 VPTransformState State (&TTI, BestVF, LI, DT, ILV.AC , ILV.Builder , &BestVPlan,
@@ -7442,20 +7466,6 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
74427466
74437467 ILV.printDebugTracesAtEnd ();
74447468
7445- // 4. Adjust branch weight of the branch in the middle block.
7446- if (HeaderVPBB) {
7447- auto *MiddleTerm =
7448- cast<BranchInst>(State.CFG .VPBB2IRBB [MiddleVPBB]->getTerminator ());
7449- if (MiddleTerm->isConditional () &&
7450- hasBranchWeightMD (*OrigLoop->getLoopLatch ()->getTerminator ())) {
7451- // Assume that `Count % VectorTripCount` is equally distributed.
7452- unsigned TripCount = BestVPlan.getUF () * State.VF .getKnownMinValue ();
7453- assert (TripCount > 0 && " trip count should not be zero" );
7454- const uint32_t Weights[] = {1 , TripCount - 1 };
7455- setBranchWeights (*MiddleTerm, Weights, /* IsExpected=*/ false );
7456- }
7457- }
7458-
74597469 return ExpandedSCEVs;
74607470}
74617471
0 commit comments