@@ -12043,6 +12043,9 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
1204312043 /// Adds 2 input vectors and the mask for their shuffling.
1204412044 void add(Value *V1, Value *V2, ArrayRef<int> Mask) {
1204512045 assert(V1 && V2 && !Mask.empty() && "Expected non-empty input vectors.");
12046+ assert(isa<FixedVectorType>(V1->getType()) &&
12047+ isa<FixedVectorType>(V2->getType()) &&
12048+ "castToScalarTyElem expects V1 and V2 to be FixedVectorType");
1204612049 V1 = castToScalarTyElem(V1);
1204712050 V2 = castToScalarTyElem(V2);
1204812051 if (InVectors.empty()) {
@@ -12072,22 +12075,18 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
1207212075 }
1207312076 /// Adds another one input vector and the mask for the shuffling.
1207412077 void add(Value *V1, ArrayRef<int> Mask, bool = false) {
12078+ assert(isa<FixedVectorType>(V1->getType()) &&
12079+ "castToScalarTyElem expects V1 to be FixedVectorType");
1207512080 V1 = castToScalarTyElem(V1);
1207612081 if (InVectors.empty()) {
12077- if (!isa<FixedVectorType>(V1->getType())) {
12078- V1 = createShuffle(V1, nullptr, CommonMask);
12079- CommonMask.assign(Mask.size(), PoisonMaskElem);
12080- transformMaskAfterShuffle(CommonMask, Mask);
12081- }
1208212082 InVectors.push_back(V1);
1208312083 CommonMask.assign(Mask.begin(), Mask.end());
1208412084 return;
1208512085 }
1208612086 const auto *It = find(InVectors, V1);
1208712087 if (It == InVectors.end()) {
1208812088 if (InVectors.size() == 2 ||
12089- InVectors.front()->getType() != V1->getType() ||
12090- !isa<FixedVectorType>(V1->getType())) {
12089+ InVectors.front()->getType() != V1->getType()) {
1209112090 Value *V = InVectors.front();
1209212091 if (InVectors.size() == 2) {
1209312092 V = createShuffle(InVectors.front(), InVectors.back(), CommonMask);
@@ -12121,9 +12120,7 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
1212112120 break;
1212212121 }
1212312122 }
12124- int VF = CommonMask.size();
12125- if (auto *FTy = dyn_cast<FixedVectorType>(V1->getType()))
12126- VF = FTy->getNumElements();
12123+ int VF = cast<FixedVectorType>(V1->getType())->getNumElements();
1212712124 for (unsigned Idx = 0, Sz = CommonMask.size(); Idx < Sz; ++Idx)
1212812125 if (Mask[Idx] != PoisonMaskElem && CommonMask[Idx] == PoisonMaskElem)
1212912126 CommonMask[Idx] = Mask[Idx] + (It == InVectors.begin() ? 0 : VF);
0 commit comments