@@ -2287,9 +2287,13 @@ class BoUpSLP {
22872287 /// a smaller type with a truncation. We collect the values that will be
22882288 /// demoted in ToDemote and additional roots that require investigating in
22892289 /// Roots.
2290- bool collectValuesToDemote(Value *V, SmallVectorImpl<Value *> &ToDemote,
2291- SmallVectorImpl<Value *> &Roots,
2292- DenseSet<Value *> &Visited) const;
2290+ /// \param DemotedConsts map from an instruction to the indices of its constant
2291+ /// operands that are to be demoted. Required to correctly identify constant
2292+ /// nodes to be demoted.
2293+ bool collectValuesToDemote(
2294+ Value *V, SmallVectorImpl<Value *> &ToDemote,
2295+ DenseMap<Instruction *, SmallVector<unsigned>> &DemotedConsts,
2296+ SmallVectorImpl<Value *> &Roots, DenseSet<Value *> &Visited) const;
22932297
22942298 /// Check if the operands on the edges \p Edges of the \p UserTE allows
22952299 /// reordering (i.e. the operands can be reordered because they have only one
@@ -2352,6 +2356,9 @@ class BoUpSLP {
23522356 /// of a vector of (the same) instruction.
23532357 TargetTransformInfo::OperandValueInfo getOperandInfo(ArrayRef<Value *> Ops);
23542358
2359+ /// \returns the graph entry for the \p Idx operand of the \p E entry.
2360+ const TreeEntry *getOperandEntry(const TreeEntry *E, unsigned Idx) const;
2361+
23552362 /// \returns the cost of the vectorizable entry.
23562363 InstructionCost getEntryCost(const TreeEntry *E,
23572364 ArrayRef<Value *> VectorizedVals,
@@ -3594,7 +3601,7 @@ class BoUpSLP {
35943601 /// where "width" indicates the minimum bit width and "signed" is True if the
35953602 /// value must be signed-extended, rather than zero-extended, back to its
35963603 /// original width.
3597- DenseMap<Value *, std::pair<uint64_t, bool>> MinBWs;
3604+ DenseMap<const TreeEntry *, std::pair<uint64_t, bool>> MinBWs;
35983605};
35993606
36003607} // end namespace slpvectorizer
@@ -7579,7 +7586,36 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
75797586 assert((IsFinalized || CommonMask.empty()) &&
75807587 "Shuffle construction must be finalized.");
75817588 }
7582- };
7589+ };
7590+
7591+ const BoUpSLP::TreeEntry *BoUpSLP::getOperandEntry(const TreeEntry *E,
7592+ unsigned Idx) const {
7593+ Value *Op = E->getOperand(Idx).front(); // First scalar of operand Idx of E.
7594+ if (const TreeEntry *TE = getTreeEntry(Op)) {
7595+ if (find_if(TE->UserTreeIndices, [&](const EdgeInfo &EI) {
7596+ return EI.EdgeIdx == Idx && EI.UserTE == E;
7597+ }) != TE->UserTreeIndices.end()) // Search and end() must use the same list.
7598+ return TE;
7599+ auto MIt = MultiNodeScalars.find(Op); // Further entries recorded for Op.
7600+ if (MIt != MultiNodeScalars.end()) {
7601+ for (const TreeEntry *TE : MIt->second) {
7602+ if (find_if(TE->UserTreeIndices, [&](const EdgeInfo &EI) {
7603+ return EI.EdgeIdx == Idx && EI.UserTE == E;
7604+ }) != TE->UserTreeIndices.end())
7605+ return TE;
7606+ }
7607+ }
7608+ } // Otherwise fall back to a NeedToGather entry used through edge (E, Idx).
7609+ const auto *It =
7610+ find_if(VectorizableTree, [&](const std::unique_ptr<TreeEntry> &TE) {
7611+ return TE->State == TreeEntry::NeedToGather &&
7612+ find_if(TE->UserTreeIndices, [&](const EdgeInfo &EI) {
7613+ return EI.EdgeIdx == Idx && EI.UserTE == E;
7614+ }) != TE->UserTreeIndices.end();
7615+ });
7616+ assert(It != VectorizableTree.end() && "Expected vectorizable entry.");
7617+ return It->get();
7618+ }
75837619
75847620InstructionCost
75857621BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
@@ -7602,7 +7638,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
76027638
76037639 // If we have computed a smaller type for the expression, update VecTy so
76047640 // that the costs will be accurate.
7605- auto It = MinBWs.find(VL.front() );
7641+ auto It = MinBWs.find(E );
76067642 if (It != MinBWs.end()) {
76077643 ScalarTy = IntegerType::get(F->getContext(), It->second.first);
76087644 VecTy = FixedVectorType::get(ScalarTy, VL.size());
@@ -7616,16 +7652,6 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
76167652 return 0;
76177653 if (isa<InsertElementInst>(VL[0]))
76187654 return InstructionCost::getInvalid();
7619- // The gather nodes use small bitwidth only if all operands use the same
7620- // bitwidth. Otherwise - use the original one.
7621- if (It != MinBWs.end() && any_of(VL.drop_front(), [&](Value *V) {
7622- auto VIt = MinBWs.find(V);
7623- return VIt == MinBWs.end() || VIt->second.first != It->second.first ||
7624- VIt->second.second != It->second.second;
7625- })) {
7626- ScalarTy = VL.front()->getType();
7627- VecTy = FixedVectorType::get(ScalarTy, VL.size());
7628- }
76297655 ShuffleCostEstimator Estimator(*TTI, VectorizedVals, *this,
76307656 CheckedExtracts);
76317657 unsigned VF = E->getVectorFactor();
@@ -7851,7 +7877,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
78517877 if ((EI.UserTE->getOpcode() != Instruction::Select ||
78527878 EI.EdgeIdx != 0) &&
78537879 It != MinBWs.end()) {
7854- auto UserBWIt = MinBWs.find(EI.UserTE->Scalars.front() );
7880+ auto UserBWIt = MinBWs.find(EI.UserTE);
78557881 Type *UserScalarTy =
78567882 EI.UserTE->getOperand(EI.EdgeIdx).front()->getType();
78577883 if (UserBWIt != MinBWs.end())
@@ -8144,7 +8170,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
81448170 case Instruction::Trunc:
81458171 case Instruction::FPTrunc:
81468172 case Instruction::BitCast: {
8147- auto SrcIt = MinBWs.find(VL0->getOperand( 0));
8173+ auto SrcIt = MinBWs.find(getOperandEntry(E, 0));
81488174 Type *SrcScalarTy = VL0->getOperand(0)->getType();
81498175 auto *SrcVecTy = FixedVectorType::get(SrcScalarTy, VL.size());
81508176 unsigned Opcode = ShuffleOrOp;
@@ -9009,7 +9035,7 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
90099035 FirstUsers.emplace_back(VU, ScalarTE);
90109036 DemandedElts.push_back(APInt::getZero(FTy->getNumElements()));
90119037 VecId = FirstUsers.size() - 1;
9012- auto It = MinBWs.find(EU.Scalar );
9038+ auto It = MinBWs.find(ScalarTE );
90139039 if (It != MinBWs.end() && VectorCasts.insert(EU.Scalar).second) {
90149040 unsigned BWSz = It->second.second;
90159041 unsigned SrcBWSz = DL->getTypeSizeInBits(FTy->getElementType());
@@ -9052,7 +9078,7 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
90529078 // for the extract and the added cost of the sign extend if needed.
90539079 auto *VecTy = FixedVectorType::get(EU.Scalar->getType(), BundleWidth);
90549080 TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
9055- auto It = MinBWs.find(EU.Scalar);
9081+ auto It = MinBWs.find(getTreeEntry( EU.Scalar) );
90569082 if (It != MinBWs.end()) {
90579083 auto *MinTy = IntegerType::get(F->getContext(), It->second.first);
90589084 unsigned Extend =
@@ -9067,9 +9093,9 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
90679093 }
90689094 // Add reduced value cost, if resized.
90699095 if (!VectorizedVals.empty()) {
9070- auto BWIt = MinBWs.find(VectorizableTree.front()->Scalars.front ());
9096+ auto BWIt = MinBWs.find(VectorizableTree.front().get ());
90719097 if (BWIt != MinBWs.end()) {
9072- Type *DstTy = BWIt->first ->getType();
9098+ Type *DstTy = VectorizableTree.front()->Scalars.front() ->getType();
90739099 unsigned OriginalSz = DL->getTypeSizeInBits(DstTy);
90749100 unsigned Opcode = Instruction::Trunc;
90759101 if (OriginalSz < BWIt->second.first)
@@ -9430,7 +9456,7 @@ BoUpSLP::isGatherShuffledSingleRegisterEntry(
94309456 Instruction &LastBundleInst = getLastInstructionInBundle(VTE);
94319457 if (&LastBundleInst == TEInsertPt || !CheckOrdering(&LastBundleInst))
94329458 continue;
9433- auto It = MinBWs.find(VTE->Scalars.front() );
9459+ auto It = MinBWs.find(VTE);
94349460 // If vectorize node is demoted - do not match.
94359461 if (It != MinBWs.end() &&
94369462 It->second.first != DL->getTypeSizeInBits(V->getType()))
@@ -11068,7 +11094,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
1106811094 else if (auto *IE = dyn_cast<InsertElementInst>(VL0))
1106911095 ScalarTy = IE->getOperand(1)->getType();
1107011096 bool IsSigned = false;
11071- auto It = MinBWs.find(E->Scalars.front() );
11097+ auto It = MinBWs.find(E);
1107211098 if (It != MinBWs.end()) {
1107311099 ScalarTy = IntegerType::get(F->getContext(), It->second.first);
1107411100 IsSigned = It->second.second;
@@ -11130,7 +11156,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
1113011156 Builder.SetCurrentDebugLocation(PH->getDebugLoc());
1113111157 Value *Vec = vectorizeOperand(E, I, /*PostponedPHIs=*/true);
1113211158 if (VecTy != Vec->getType()) {
11133- assert(MinBWs.contains(PH->getIncomingValue( I)) &&
11159+ assert(MinBWs.contains(getOperandEntry(E, I)) &&
1113411160 "Expected item in MinBWs.");
1113511161 Vec = Builder.CreateIntCast(Vec, VecTy, It->second.second);
1113611162 }
@@ -11167,7 +11193,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
1116711193 Type *ScalarTy = Op.front()->getType();
1116811194 if (cast<VectorType>(V->getType())->getElementType() != ScalarTy) {
1116911195 assert(ScalarTy->isIntegerTy() && "Expected item in MinBWs.");
11170- std::pair<unsigned, bool> Res = MinBWs.lookup(Op.front( ));
11196+ std::pair<unsigned, bool> Res = MinBWs.lookup(getOperandEntry(E, 1 ));
1117111197 assert(Res.first > 0 && "Expected item in MinBWs.");
1117211198 V = Builder.CreateIntCast(
1117311199 V,
@@ -11342,7 +11368,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
1134211368 auto *CI = cast<CastInst>(VL0);
1134311369 Instruction::CastOps VecOpcode = CI->getOpcode();
1134411370 Type *SrcScalarTy = VL0->getOperand(0)->getType();
11345- auto SrcIt = MinBWs.find(VL0->getOperand( 0));
11371+ auto SrcIt = MinBWs.find(getOperandEntry(E, 0));
1134611372 if (!ScalarTy->isFloatingPointTy() && !SrcScalarTy->isFloatingPointTy() &&
1134711373 (SrcIt != MinBWs.end() || It != MinBWs.end())) {
1134811374 // Check if the values are candidates to demote.
@@ -11383,8 +11409,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
1138311409 return E->VectorizedValue;
1138411410 }
1138511411 if (L->getType() != R->getType()) {
11386- assert((MinBWs.contains(VL0->getOperand( 0)) ||
11387- MinBWs.contains(VL0->getOperand( 1))) &&
11412+ assert((MinBWs.contains(getOperandEntry(E, 0)) ||
11413+ MinBWs.contains(getOperandEntry(E, 1))) &&
1138811414 "Expected item in MinBWs.");
1138911415 L = Builder.CreateIntCast(L, VecTy, IsSigned);
1139011416 R = Builder.CreateIntCast(R, VecTy, IsSigned);
@@ -11420,8 +11446,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
1142011446 return E->VectorizedValue;
1142111447 }
1142211448 if (True->getType() != False->getType()) {
11423- assert((MinBWs.contains(VL0->getOperand( 1)) ||
11424- MinBWs.contains(VL0->getOperand( 2))) &&
11449+ assert((MinBWs.contains(getOperandEntry(E, 1)) ||
11450+ MinBWs.contains(getOperandEntry(E, 2))) &&
1142511451 "Expected item in MinBWs.");
1142611452 True = Builder.CreateIntCast(True, VecTy, IsSigned);
1142711453 False = Builder.CreateIntCast(False, VecTy, IsSigned);
@@ -11488,8 +11514,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
1148811514 return E->VectorizedValue;
1148911515 }
1149011516 if (LHS->getType() != RHS->getType()) {
11491- assert((MinBWs.contains(VL0->getOperand( 0)) ||
11492- MinBWs.contains(VL0->getOperand( 1))) &&
11517+ assert((MinBWs.contains(getOperandEntry(E, 0)) ||
11518+ MinBWs.contains(getOperandEntry(E, 1))) &&
1149311519 "Expected item in MinBWs.");
1149411520 LHS = Builder.CreateIntCast(LHS, VecTy, IsSigned);
1149511521 RHS = Builder.CreateIntCast(RHS, VecTy, IsSigned);
@@ -11725,8 +11751,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
1172511751 return E->VectorizedValue;
1172611752 }
1172711753 if (LHS && RHS && LHS->getType() != RHS->getType()) {
11728- assert((MinBWs.contains(VL0->getOperand( 0)) ||
11729- MinBWs.contains(VL0->getOperand( 1))) &&
11754+ assert((MinBWs.contains(getOperandEntry(E, 0)) ||
11755+ MinBWs.contains(getOperandEntry(E, 1))) &&
1173011756 "Expected item in MinBWs.");
1173111757 LHS = Builder.CreateIntCast(LHS, VecTy, IsSigned);
1173211758 RHS = Builder.CreateIntCast(RHS, VecTy, IsSigned);
@@ -11962,7 +11988,7 @@ Value *BoUpSLP::vectorizeTree(
1196211988 // to the larger type.
1196311989 if (Scalar->getType() != Ex->getType())
1196411990 return Builder.CreateIntCast(Ex, Scalar->getType(),
11965- MinBWs.find(Scalar )->second.second);
11991+ MinBWs.find(E )->second.second);
1196611992 return Ex;
1196711993 }
1196811994 assert(isa<FixedVectorType>(Scalar->getType()) &&
@@ -12003,7 +12029,7 @@ Value *BoUpSLP::vectorizeTree(
1200312029 if (!UsedInserts.insert(VU).second)
1200412030 continue;
1200512031 // Need to use original vector, if the root is truncated.
12006- auto BWIt = MinBWs.find(Scalar );
12032+ auto BWIt = MinBWs.find(E );
1200712033 if (BWIt != MinBWs.end() && Vec->getType() != VU->getType()) {
1200812034 auto VecIt = VectorCasts.find(Scalar);
1200912035 if (VecIt == VectorCasts.end()) {
@@ -13081,22 +13107,22 @@ unsigned BoUpSLP::getVectorElementSize(Value *V) {
1308113107// Determine if a value V in a vectorizable expression Expr can be demoted to a
1308213108// smaller type with a truncation. We collect the values that will be demoted
1308313109// in ToDemote and additional roots that require investigating in Roots.
13084- bool BoUpSLP::collectValuesToDemote(Value *V,
13085- SmallVectorImpl<Value *> &ToDemote,
13086- SmallVectorImpl<Value *> &Roots ,
13087- DenseSet<Value *> &Visited) const {
13110+ bool BoUpSLP::collectValuesToDemote(
13111+ Value *V, SmallVectorImpl<Value *> &ToDemote,
13112+ DenseMap<Instruction *, SmallVector<unsigned>> &DemotedConsts ,
13113+ SmallVectorImpl<Value *> &Roots, DenseSet<Value *> &Visited) const {
1308813114 // We can always demote constants.
13089- if (isa<Constant>(V)) {
13090- ToDemote.push_back(V);
13115+ if (isa<Constant>(V))
1309113116 return true;
13092- }
1309313117
1309413118 // If the value is not a vectorized instruction in the expression with only
1309513119 // one use, it cannot be demoted.
1309613120 auto *I = dyn_cast<Instruction>(V);
1309713121 if (!I || !I->hasOneUse() || !getTreeEntry(I) || !Visited.insert(I).second)
1309813122 return false;
1309913123
13124+ unsigned Start = 0;
13125+ unsigned End = I->getNumOperands();
1310013126 switch (I->getOpcode()) {
1310113127
1310213128 // We can always demote truncations and extensions. Since truncations can
@@ -13118,16 +13144,21 @@ bool BoUpSLP::collectValuesToDemote(Value *V,
1311813144 case Instruction::And:
1311913145 case Instruction::Or:
1312013146 case Instruction::Xor:
13121- if (!collectValuesToDemote(I->getOperand(0), ToDemote, Roots, Visited) ||
13122- !collectValuesToDemote(I->getOperand(1), ToDemote, Roots, Visited))
13147+ if (!collectValuesToDemote(I->getOperand(0), ToDemote, DemotedConsts, Roots,
13148+ Visited) ||
13149+ !collectValuesToDemote(I->getOperand(1), ToDemote, DemotedConsts, Roots,
13150+ Visited))
1312313151 return false;
1312413152 break;
1312513153
1312613154 // We can demote selects if we can demote their true and false values.
1312713155 case Instruction::Select: {
13156+ Start = 1;
1312813157 SelectInst *SI = cast<SelectInst>(I);
13129- if (!collectValuesToDemote(SI->getTrueValue(), ToDemote, Roots, Visited) ||
13130- !collectValuesToDemote(SI->getFalseValue(), ToDemote, Roots, Visited))
13158+ if (!collectValuesToDemote(SI->getTrueValue(), ToDemote, DemotedConsts,
13159+ Roots, Visited) ||
13160+ !collectValuesToDemote(SI->getFalseValue(), ToDemote, DemotedConsts,
13161+ Roots, Visited))
1313113162 return false;
1313213163 break;
1313313164 }
@@ -13137,7 +13168,8 @@ bool BoUpSLP::collectValuesToDemote(Value *V,
1313713168 case Instruction::PHI: {
1313813169 PHINode *PN = cast<PHINode>(I);
1313913170 for (Value *IncValue : PN->incoming_values())
13140- if (!collectValuesToDemote(IncValue, ToDemote, Roots, Visited))
13171+ if (!collectValuesToDemote(IncValue, ToDemote, DemotedConsts, Roots,
13172+ Visited))
1314113173 return false;
1314213174 break;
1314313175 }
@@ -13147,6 +13179,10 @@ bool BoUpSLP::collectValuesToDemote(Value *V,
1314713179 return false;
1314813180 }
1314913181
13182+ // Gather demoted constant operands.
13183+ for (unsigned Idx : seq<unsigned>(Start, End))
13184+ if (isa<Constant>(I->getOperand(Idx)))
13185+ DemotedConsts.try_emplace(I).first->getSecond().push_back(Idx);
1315013186 // Record the value that we can demote.
1315113187 ToDemote.push_back(V);
1315213188 return true;
@@ -13172,10 +13208,11 @@ void BoUpSLP::computeMinimumValueSizes() {
1317213208 // expression. Collect the values that can be demoted in ToDemote and
1317313209 // additional roots that require investigating in Roots.
1317413210 SmallVector<Value *, 32> ToDemote;
13211+ DenseMap<Instruction *, SmallVector<unsigned>> DemotedConsts;
1317513212 SmallVector<Value *, 4> Roots;
1317613213 for (auto *Root : TreeRoot) {
1317713214 DenseSet<Value *> Visited;
13178- if (!collectValuesToDemote(Root, ToDemote, Roots, Visited))
13215+ if (!collectValuesToDemote(Root, ToDemote, DemotedConsts, Roots, Visited))
1317913216 return;
1318013217 }
1318113218
@@ -13260,26 +13297,36 @@ void BoUpSLP::computeMinimumValueSizes() {
1326013297 // modify.
1326113298 while (!Roots.empty()) {
1326213299 DenseSet<Value *> Visited;
13263- collectValuesToDemote(Roots.pop_back_val(), ToDemote, Roots, Visited);
13300+ collectValuesToDemote(Roots.pop_back_val(), ToDemote, DemotedConsts, Roots,
13301+ Visited);
1326413302 }
1326513303
1326613304 // Finally, map the values we can demote to the maximum bit width we computed.
13267- DenseMap<const TreeEntry *, bool> Signendness;
1326813305 for (auto *Scalar : ToDemote) {
13269- bool IsSigned = true;
13270- if (auto *TE = getTreeEntry(Scalar)) {
13271- auto It = Signendness.find(TE);
13272- if (It != Signendness.end()) {
13273- IsSigned = It->second;
13274- } else {
13275- IsSigned = any_of(TE->Scalars, [&](Value *R) {
13276- KnownBits Known = computeKnownBits(R, *DL);
13277- return !Known.isNonNegative();
13278- });
13279- Signendness.try_emplace(TE, IsSigned);
13306+ auto *TE = getTreeEntry(Scalar);
13307+ assert(TE && "Expected vectorized scalar.");
13308+ if (MinBWs.contains(TE))
13309+ continue;
13310+ bool IsSigned = any_of(TE->Scalars, [&](Value *R) {
13311+ KnownBits Known = computeKnownBits(R, *DL);
13312+ return !Known.isNonNegative();
13313+ });
13314+ MinBWs.try_emplace(TE, MaxBitWidth, IsSigned);
13315+ const auto *I = cast<Instruction>(Scalar);
13316+ auto DCIt = DemotedConsts.find(I);
13317+ if (DCIt != DemotedConsts.end()) {
13318+ for (unsigned Idx : DCIt->getSecond()) {
13319+ // Check that all instructions operands are demoted.
13320+ if (all_of(TE->Scalars, [&](Value *V) {
13321+ auto SIt = DemotedConsts.find(cast<Instruction>(V));
13322+ return SIt != DemotedConsts.end() &&
13323+ is_contained(SIt->getSecond(), Idx);
13324+ })) {
13325+ const TreeEntry *CTE = getOperandEntry(TE, Idx);
13326+ MinBWs.try_emplace(CTE, MaxBitWidth, IsSigned);
13327+ }
1328013328 }
1328113329 }
13282- MinBWs.try_emplace(Scalar, MaxBitWidth, IsSigned);
1328313330 }
1328413331}
1328513332
0 commit comments