@@ -1035,50 +1035,61 @@ bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
10351035 if (match (U, m_Select (m_Specific (&I), m_Value (), m_Value ())))
10361036 return false ;
10371037
1038- // Match against one or both scalar values being inserted into constant
1039- // vectors:
1040- // vec_op VecC0, (inselt VecC1, V1, Index)
1041- // vec_op (inselt VecC0, V0, Index), VecC1
1042- // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index)
1043- // TODO: Deal with mismatched index constants and variable indexes?
10441038 Constant *VecC0 = nullptr , *VecC1 = nullptr ;
10451039 Value *V0 = nullptr , *V1 = nullptr ;
1046- uint64_t Index0 = 0 , Index1 = 0 ;
1047- if (!match (Ins0, m_InsertElt (m_Constant (VecC0), m_Value (V0),
1048- m_ConstantInt (Index0))) &&
1049- !match (Ins0, m_Constant (VecC0)))
1050- return false ;
1051- if (!match (Ins1, m_InsertElt (m_Constant (VecC1), m_Value (V1),
1052- m_ConstantInt (Index1))) &&
1053- !match (Ins1, m_Constant (VecC1)))
1054- return false ;
1040+ std::optional<uint64_t > Index;
1041+
1042+ // Try and match against two splatted operands first.
1043+ // vec_op (splat V0), (splat V1)
1044+ V0 = getSplatValue (Ins0);
1045+ V1 = getSplatValue (Ins1);
1046+ if (!V0 || !V1) {
1047+ // Match against one or both scalar values being inserted into constant
1048+ // vectors:
1049+ // vec_op VecC0, (inselt VecC1, V1, Index)
1050+ // vec_op (inselt VecC0, V0, Index), VecC1
1051+ // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index)
1052+ // TODO: Deal with mismatched index constants and variable indexes?
1053+ V0 = nullptr , V1 = nullptr ;
1054+ uint64_t Index0 = 0 , Index1 = 0 ;
1055+ if (!match (Ins0, m_InsertElt (m_Constant (VecC0), m_Value (V0),
1056+ m_ConstantInt (Index0))) &&
1057+ !match (Ins0, m_Constant (VecC0)))
1058+ return false ;
1059+ if (!match (Ins1, m_InsertElt (m_Constant (VecC1), m_Value (V1),
1060+ m_ConstantInt (Index1))) &&
1061+ !match (Ins1, m_Constant (VecC1)))
1062+ return false ;
10551063
1056- bool IsConst0 = !V0;
1057- bool IsConst1 = !V1;
1058- if (IsConst0 && IsConst1)
1059- return false ;
1060- if (!IsConst0 && !IsConst1 && Index0 != Index1)
1061- return false ;
1064+ bool IsConst0 = !V0;
1065+ bool IsConst1 = !V1;
1066+ if (IsConst0 && IsConst1)
1067+ return false ;
1068+ if (!IsConst0 && !IsConst1 && Index0 != Index1)
1069+ return false ;
10621070
1063- auto *VecTy0 = cast<VectorType>(Ins0->getType ());
1064- auto *VecTy1 = cast<VectorType>(Ins1->getType ());
1065- if (VecTy0->getElementCount ().getKnownMinValue () <= Index0 ||
1066- VecTy1->getElementCount ().getKnownMinValue () <= Index1)
1067- return false ;
1071+ auto *VecTy0 = cast<VectorType>(Ins0->getType ());
1072+ auto *VecTy1 = cast<VectorType>(Ins1->getType ());
1073+ if (VecTy0->getElementCount ().getKnownMinValue () <= Index0 ||
1074+ VecTy1->getElementCount ().getKnownMinValue () <= Index1)
1075+ return false ;
10681076
1069- // Bail for single insertion if it is a load.
1070- // TODO: Handle this once getVectorInstrCost can cost for load/stores.
1071- auto *I0 = dyn_cast_or_null<Instruction>(V0);
1072- auto *I1 = dyn_cast_or_null<Instruction>(V1);
1073- if ((IsConst0 && I1 && I1->mayReadFromMemory ()) ||
1074- (IsConst1 && I0 && I0->mayReadFromMemory ()))
1075- return false ;
1077+ // Bail for single insertion if it is a load.
1078+ // TODO: Handle this once getVectorInstrCost can cost for load/stores.
1079+ auto *I0 = dyn_cast_or_null<Instruction>(V0);
1080+ auto *I1 = dyn_cast_or_null<Instruction>(V1);
1081+ if ((IsConst0 && I1 && I1->mayReadFromMemory ()) ||
1082+ (IsConst1 && I0 && I0->mayReadFromMemory ()))
1083+ return false ;
1084+
1085+ Index = IsConst0 ? Index1 : Index0;
1086+ }
10761087
1077- uint64_t Index = IsConst0 ? Index1 : Index0;
1078- Type *ScalarTy = IsConst0 ? V1->getType () : V0->getType ();
1079- Type *VecTy = I.getType ();
1088+ auto *VecTy = cast<VectorType>(I.getType ());
1089+ Type *ScalarTy = VecTy->getElementType ();
10801090 assert (VecTy->isVectorTy () &&
1081- (IsConst0 || IsConst1 || V0->getType () == V1->getType ()) &&
1091+ (isa<Constant>(Ins0) || isa<Constant>(Ins1) ||
1092+ V0->getType () == V1->getType ()) &&
10821093 (ScalarTy->isIntegerTy () || ScalarTy->isFloatingPointTy () ||
10831094 ScalarTy->isPointerTy ()) &&
10841095 " Unexpected types for insert element into binop or cmp" );
@@ -1099,29 +1110,33 @@ bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
10991110 // Get cost estimate for the insert element. This cost will factor into
11001111 // both sequences.
11011112 InstructionCost InsertCost = TTI.getVectorInstrCost (
1102- Instruction::InsertElement, VecTy, CostKind, Index);
1103- InstructionCost OldCost =
1104- (IsConst0 ? 0 : InsertCost) + (IsConst1 ? 0 : InsertCost) + VectorOpCost;
1105- InstructionCost NewCost = ScalarOpCost + InsertCost +
1106- (IsConst0 ? 0 : !Ins0->hasOneUse () * InsertCost) +
1107- (IsConst1 ? 0 : !Ins1->hasOneUse () * InsertCost);
1113+ Instruction::InsertElement, VecTy, CostKind, Index.value_or (0 ));
1114+ InstructionCost OldCost = (isa<Constant>(Ins0) ? 0 : InsertCost) +
1115+ (isa<Constant>(Ins1) ? 0 : InsertCost) +
1116+ VectorOpCost;
1117+ InstructionCost NewCost =
1118+ ScalarOpCost + InsertCost +
1119+ (isa<Constant>(Ins0) ? 0 : !Ins0->hasOneUse () * InsertCost) +
1120+ (isa<Constant>(Ins1) ? 0 : !Ins1->hasOneUse () * InsertCost);
11081121
11091122 // We want to scalarize unless the vector variant actually has lower cost.
11101123 if (OldCost < NewCost || !NewCost.isValid ())
11111124 return false ;
11121125
11131126 // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index) -->
11141127 // inselt NewVecC, (scalar_op V0, V1), Index
1128+ //
1129+ // vec_op (splat V0), (splat V1) --> splat (scalar_op V0, V1)
11151130 if (IsCmp)
11161131 ++NumScalarCmp;
11171132 else
11181133 ++NumScalarBO;
11191134
11201135 // For constant cases, extract the scalar element, this should constant fold.
1121- if (IsConst0 )
1122- V0 = ConstantExpr::getExtractElement (VecC0, Builder.getInt64 (Index));
1123- if (IsConst1 )
1124- V1 = ConstantExpr::getExtractElement (VecC1, Builder.getInt64 (Index));
1136+ if (Index && isa<Constant>(Ins0) )
1137+ V0 = ConstantExpr::getExtractElement (VecC0, Builder.getInt64 (* Index));
1138+ if (Index && isa<Constant>(Ins1) )
1139+ V1 = ConstantExpr::getExtractElement (VecC1, Builder.getInt64 (* Index));
11251140
11261141 Value *Scalar =
11271142 IsCmp ? Builder.CreateCmp (Pred, V0, V1)
@@ -1134,12 +1149,16 @@ bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
11341149 if (auto *ScalarInst = dyn_cast<Instruction>(Scalar))
11351150 ScalarInst->copyIRFlags (&I);
11361151
1137- // Fold the vector constants in the original vectors into a new base vector.
1138- Value *NewVecC =
1139- IsCmp ? Builder.CreateCmp (Pred, VecC0, VecC1)
1140- : Builder.CreateBinOp ((Instruction::BinaryOps)Opcode, VecC0, VecC1);
1141- Value *Insert = Builder.CreateInsertElement (NewVecC, Scalar, Index);
1142- replaceValue (I, *Insert);
1152+ Value *Result;
1153+ if (Index) {
1154+ // Fold the vector constants in the original vectors into a new base vector.
1155+ Value *NewVecC = IsCmp ? Builder.CreateCmp (Pred, VecC0, VecC1)
1156+ : Builder.CreateBinOp ((Instruction::BinaryOps)Opcode,
1157+ VecC0, VecC1);
1158+ Result = Builder.CreateInsertElement (NewVecC, Scalar, *Index);
1159+ } else
1160+ Result = Builder.CreateVectorSplat (VecTy->getElementCount (), Scalar);
1161+ replaceValue (I, *Result);
11431162 return true ;
11441163}
11451164
0 commit comments