@@ -7135,7 +7135,6 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
7135
7135
// into a vector and can be represented as a permutation of elements in a
7136
7136
// single input vector or of 2 input vectors.
7137
7137
Cost += computeExtractCost(VL, Mask, ShuffleKind);
7138
- InVectors.assign(1, E);
7139
7138
return VecBase;
7140
7139
}
7141
7140
void add(const TreeEntry *E1, const TreeEntry *E2, ArrayRef<int> Mask) {
@@ -7146,18 +7145,57 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
7146
7145
CommonMask.assign(Mask.begin(), Mask.end());
7147
7146
InVectors.assign(1, E1);
7148
7147
}
7149
- void gather(ArrayRef<Value *> VL, Value *Root = nullptr) {
7148
+ /// Adds a single input vector \p V1 and the mask \p Mask used to shuffle it.
7149
+ void add(Value *V1, ArrayRef<int> Mask) {
// Must be the first input registered: the assert below requires that both
// CommonMask and InVectors are still empty.
7150
+ assert(CommonMask.empty() && InVectors.empty() &&
7151
+ "Expected empty input mask/vectors.");
// Copy the caller's mask into CommonMask and record V1 as the only input.
7152
+ CommonMask.assign(Mask.begin(), Mask.end());
7153
+ InVectors.assign(1, V1);
7154
+ }
7155
+ Value *gather(ArrayRef<Value *> VL, Value *Root = nullptr) {
// Adds the result of getBuildVectorCost(VL, Root) to Cost and returns a
// constant vector with VL.size() elements.
// NOTE(review): the returned constant looks like a stand-in for the real
// buildvector, since this class only estimates cost - confirm with callers.
7150
7156
Cost += getBuildVectorCost(VL, Root);
7151
7157
if (!Root) {
7152
7158
assert(InVectors.empty() && "Unexpected input vectors for buildvector.");
7153
7159
// FIXME: Need to find a way to avoid use of getNullValue here.
7154
- InVectors.assign(1, Constant::getNullValue(FixedVectorType::get(
7155
- VL.front()->getType(), VL.size())));
7160
+ SmallVector<Constant *> Vals;
// Build the result element by element: elements of VL that are UndefValue
// are kept as they are, every other element becomes the null value of its
// own type.
7161
+ for (Value *V : VL) {
7162
+ if (isa<UndefValue>(V)) {
7163
+ Vals.push_back(cast<Constant>(V));
7164
+ continue;
7165
+ }
7166
+ Vals.push_back(Constant::getNullValue(V->getType()));
7167
+ }
7168
+ return ConstantVector::get(Vals);
7156
7169
}
// A Root was supplied: return a splat of the null value of VL.front()'s type
// with a fixed element count of VL.size().
7170
+ return ConstantVector::getSplat(
7171
+ ElementCount::getFixed(VL.size()),
7172
+ Constant::getNullValue(VL.front()->getType()));
7157
7173
}
7158
7174
/// Finalize emission of the shuffles.
7159
- InstructionCost finalize(ArrayRef<int> ExtMask) {
7175
+ InstructionCost
7176
+ finalize(ArrayRef<int> ExtMask, unsigned VF = 0,
7177
+ function_ref<void(Value *&, SmallVectorImpl<int> &)> Action = {}) {
7160
7178
IsFinalized = true;
7179
+ if (Action) {
7180
+ const PointerUnion<Value *, const TreeEntry *> &Vec = InVectors.front();
7181
+ if (InVectors.size() == 2) {
7182
+ Cost += createShuffle(Vec, InVectors.back(), CommonMask);
7183
+ InVectors.pop_back();
7184
+ } else {
7185
+ Cost += createShuffle(Vec, nullptr, CommonMask);
7186
+ }
7187
+ for (unsigned Idx = 0, Sz = CommonMask.size(); Idx < Sz; ++Idx)
7188
+ if (CommonMask[Idx] != PoisonMaskElem)
7189
+ CommonMask[Idx] = Idx;
7190
+ assert(VF > 0 &&
7191
+ "Expected vector length for the final value before action.");
7192
+ Value *V = Vec.dyn_cast<Value *>();
7193
+ if (!Vec.isNull() && !V)
7194
+ V = Constant::getNullValue(FixedVectorType::get(
7195
+ Vec.get<const TreeEntry *>()->Scalars.front()->getType(),
7196
+ CommonMask.size()));
7197
+ Action(V, CommonMask);
7198
+ }
7161
7199
::addMask(CommonMask, ExtMask, /*ExtendingManyInputs=*/true);
7162
7200
if (CommonMask.empty())
7163
7201
return Cost;
@@ -7291,18 +7329,31 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
7291
7329
Estimator.add(Entries.front(), Mask);
7292
7330
else
7293
7331
Estimator.add(Entries.front(), Entries.back(), Mask);
7294
- Estimator.gather(
7295
- GatheredScalars,
7296
- Constant::getNullValue(FixedVectorType::get(
7297
- GatheredScalars.front()->getType(), GatheredScalars.size())));
7298
- return Estimator.finalize(E->ReuseShuffleIndices);
7299
- }
7300
- Estimator.gather(
7301
- GatheredScalars,
7302
- VL.equals(GatheredScalars)
7303
- ? nullptr
7304
- : Constant::getNullValue(FixedVectorType::get(
7305
- GatheredScalars.front()->getType(), GatheredScalars.size())));
7332
+ if (all_of(GatheredScalars, PoisonValue ::classof))
7333
+ return Estimator.finalize(E->ReuseShuffleIndices);
7334
+ return Estimator.finalize(
7335
+ E->ReuseShuffleIndices, E->Scalars.size(),
7336
+ [&](Value *&Vec, SmallVectorImpl<int> &Mask) {
7337
+ Vec = Estimator.gather(GatheredScalars,
7338
+ Constant::getNullValue(FixedVectorType::get(
7339
+ GatheredScalars.front()->getType(),
7340
+ GatheredScalars.size())));
7341
+ });
7342
+ }
7343
+ if (!all_of(GatheredScalars, PoisonValue::classof)) {
7344
+ auto Gathers = ArrayRef(GatheredScalars).take_front(VL.size());
7345
+ bool SameGathers = VL.equals(Gathers);
7346
+ Value *BV = Estimator.gather(
7347
+ Gathers, SameGathers ? nullptr
7348
+ : Constant::getNullValue(FixedVectorType::get(
7349
+ GatheredScalars.front()->getType(),
7350
+ GatheredScalars.size())));
7351
+ SmallVector<int> ReuseMask(Gathers.size(), PoisonMaskElem);
7352
+ std::iota(ReuseMask.begin(), ReuseMask.end(), 0);
7353
+ Estimator.add(BV, ReuseMask);
7354
+ }
7355
+ if (ExtractShuffle)
7356
+ Estimator.add(E, std::nullopt);
7306
7357
return Estimator.finalize(E->ReuseShuffleIndices);
7307
7358
}
7308
7359
InstructionCost CommonCost = 0;
0 commit comments