@@ -628,6 +628,14 @@ static SmallVector<VPUser *> collectUsersRecursively(VPValue *V) {
628628 return Users.takeVector ();
629629}
630630
631+ static SmallVector<VPValue *> collectOperandsRecursively (VPRecipeBase *R) {
632+ SetVector<VPValue *> Operands (llvm::from_range, R->operands ());
633+ for (unsigned I = 0 ; I != Operands.size (); ++I)
634+ if (auto *Cur = Operands[I]->getDefiningRecipe ())
635+ Operands.insert_range (Cur->operands ());
636+ return Operands.takeVector ();
637+ }
638+
631639// / Legalize VPWidenPointerInductionRecipe, by replacing it with a PtrAdd
632640// / (IndStart, ScalarIVSteps (0, Step)) if only its scalar values are used, as
633641// / VPWidenPointerInductionRecipe will generate vectors only. If some users
@@ -4054,25 +4062,42 @@ VPlanTransforms::expandSCEVs(VPlan &Plan, ScalarEvolution &SE) {
40544062 return ExpandedSCEVs;
40554063}
40564064
4057- // / Returns true if \p V is VPWidenLoadRecipe or VPInterleaveRecipe that can be
4058- // / converted to a narrower recipe. \p V is used by a wide recipe that feeds a
4059- // / store interleave group at index \p Idx, \p WideMember0 is the recipe feeding
4060- // / the same interleave group at index 0. A VPWidenLoadRecipe can be narrowed to
4061- // / an index-independent load if it feeds all wide ops at all indices (\p OpV
4062- // / must be the operand at index \p OpIdx for both the recipe at lane 0, \p
4063- // / WideMember0). A VPInterleaveRecipe can be narrowed to a wide load, if \p V
4064- // / is defined at \p Idx of a load interleave group.
4065- static bool canNarrowLoad (VPWidenRecipe *WideMember0, unsigned OpIdx,
4066- VPValue *OpV, unsigned Idx) {
4067- auto *DefR = OpV->getDefiningRecipe ();
4068- if (!DefR)
4069- return WideMember0->getOperand (OpIdx) == OpV;
4070- if (auto *W = dyn_cast<VPWidenLoadRecipe>(DefR))
4071- return !W->getMask () && WideMember0->getOperand (OpIdx) == OpV;
4072-
4073- if (auto *IR = dyn_cast<VPInterleaveRecipe>(DefR))
4074- return IR->getInterleaveGroup ()->isFull () && IR->getVPValue (Idx) == OpV;
4075- return false ;
4065+ // / Returns true if the \p StoredValues of an interleave group match. It does
4066+ // / this by going through operands recursively until it hits the leaf cases:
4067+ // / VPWidenLoadRecipe, VPInterleaveRecipe, and live-ins.
4068+ static bool interleaveStoredValuesMatch (ArrayRef<VPValue *> StoredValues) {
4069+ auto *WideMember0 =
4070+ dyn_cast_or_null<VPWidenRecipe>(StoredValues[0 ]->getDefiningRecipe ());
4071+ if (!WideMember0)
4072+ return false ;
4073+ SmallVector<VPValue *> Ops0 = collectOperandsRecursively (WideMember0);
4074+ for (VPValue *ValI : StoredValues) {
4075+ auto *WideMemberI =
4076+ dyn_cast_or_null<VPWidenRecipe>(ValI->getDefiningRecipe ());
4077+ if (!WideMemberI || WideMemberI->getOpcode () != WideMember0->getOpcode ())
4078+ return false ;
4079+ SmallVector<VPValue *> OpsI = collectOperandsRecursively (WideMemberI);
4080+ if (Ops0.size () != OpsI.size ())
4081+ return false ;
4082+ for (const auto &[Op0, OpI] : zip (Ops0, OpsI)) {
4083+ auto *Def0 = Op0->getDefiningRecipe ();
4084+ auto *DefI = OpI->getDefiningRecipe ();
4085+ if (!Def0 || !DefI) {
4086+ if (Op0 != OpI)
4087+ return false ;
4088+ } else if (Def0->getVPDefID () != DefI->getVPDefID ()) {
4089+ return false ;
4090+ } else if (auto *W = dyn_cast<VPWidenLoadRecipe>(DefI)) {
4091+ if (W->isMasked () || Op0 != OpI)
4092+ return false ;
4093+ } else if (auto *IR = dyn_cast<VPInterleaveRecipe>(DefI)) {
4094+ if (!IR->getInterleaveGroup ()->isFull () ||
4095+ !equal (DefI->definedValues (), Def0->definedValues ()))
4096+ return false ;
4097+ }
4098+ }
4099+ }
4100+ return true ;
40764101}
40774102
40784103// / Returns true if \p IR is a full interleave group with factor and number of
@@ -4191,24 +4216,9 @@ void VPlanTransforms::narrowInterleaveGroups(VPlan &Plan, ElementCount VF,
41914216 continue ;
41924217 }
41934218
4194- // Check if all values feeding InterleaveR are matching wide recipes, which
4195- // operands that can be narrowed.
4196- auto *WideMember0 = dyn_cast_or_null<VPWidenRecipe>(
4197- InterleaveR->getStoredValues ()[0 ]->getDefiningRecipe ());
4198- if (!WideMember0)
4219+ // Check if all values feeding InterleaveR match.
4220+ if (!interleaveStoredValuesMatch (InterleaveR->getStoredValues ()))
41994221 return ;
4200- for (const auto &[I, V] : enumerate(InterleaveR->getStoredValues ())) {
4201- auto *R = dyn_cast_or_null<VPWidenRecipe>(V->getDefiningRecipe ());
4202- if (!R || R->getOpcode () != WideMember0->getOpcode () ||
4203- R->getNumOperands () > 2 )
4204- return ;
4205- if (any_of (enumerate(R->operands ()),
4206- [WideMember0, Idx = I](const auto &P) {
4207- const auto &[OpIdx, OpV] = P;
4208- return !canNarrowLoad (WideMember0, OpIdx, OpV, Idx);
4209- }))
4210- return ;
4211- }
42124222 StoreGroups.push_back (InterleaveR);
42134223 }
42144224
@@ -4240,7 +4250,11 @@ void VPlanTransforms::narrowInterleaveGroups(VPlan &Plan, ElementCount VF,
42404250 NarrowedOps.insert (RepR);
42414251 return RepR;
42424252 }
4243- auto *WideLoad = cast<VPWidenLoadRecipe>(R);
4253+ auto *WideLoad = dyn_cast<VPWidenLoadRecipe>(R);
4254+ if (!WideLoad) {
4255+ NarrowedOps.insert (V);
4256+ return V;
4257+ }
42444258
42454259 VPValue *PtrOp = WideLoad->getAddr ();
42464260 if (auto *VecPtr = dyn_cast<VPVectorPointerRecipe>(PtrOp))
0 commit comments