@@ -47,7 +47,7 @@ STATISTIC(NumVecCmp, "Number of vector compares formed");
4747STATISTIC (NumVecBO, " Number of vector binops formed" );
4848STATISTIC (NumVecCmpBO, " Number of vector compare + binop formed" );
4949STATISTIC (NumShufOfBitcast, " Number of shuffles moved after bitcast" );
50- STATISTIC (NumScalarBO , " Number of scalar binops formed" );
50+ STATISTIC (NumScalarOps , " Number of scalar unary + binary ops formed" );
5151STATISTIC (NumScalarCmp, " Number of scalar compares formed" );
5252STATISTIC (NumScalarIntrinsic, " Number of scalar intrinsic calls formed" );
5353
@@ -114,7 +114,7 @@ class VectorCombine {
114114 bool foldInsExtBinop (Instruction &I);
115115 bool foldInsExtVectorToShuffle (Instruction &I);
116116 bool foldBitcastShuffle (Instruction &I);
117- bool scalarizeBinopOrCmp (Instruction &I);
117+ bool scalarizeOpOrCmp (Instruction &I);
118118 bool scalarizeVPIntrinsic (Instruction &I);
119119 bool foldExtractedCmps (Instruction &I);
120120 bool foldBinopOfReductions (Instruction &I);
@@ -1018,91 +1018,90 @@ bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) {
10181018 return true ;
10191019}
10201020
1021- // / Match a vector binop, compare or binop-like intrinsic with at least one
1022- // / inserted scalar operand and convert to scalar binop /cmp/intrinsic followed
1021+ /// Match a vector op/compare/intrinsic with at least one
1022+ /// inserted scalar operand and convert to scalar op/cmp/intrinsic followed
10231023/// by insertelement.
1024- bool VectorCombine::scalarizeBinopOrCmp (Instruction &I) {
1025- CmpPredicate Pred = CmpInst::BAD_ICMP_PREDICATE;
1026- Value *Ins0, *Ins1;
1027- if (!match (&I, m_BinOp (m_Value (Ins0), m_Value (Ins1))) &&
1028- !match (&I, m_Cmp (Pred, m_Value (Ins0), m_Value (Ins1)))) {
1029- // TODO: Allow unary and ternary intrinsics
1030- // TODO: Allow intrinsics with different argument types
1031- // TODO: Allow intrinsics with scalar arguments
1032- if (auto *II = dyn_cast<IntrinsicInst>(&I);
1033- II && II->arg_size () == 2 &&
1034- isTriviallyVectorizable (II->getIntrinsicID ()) &&
1035- all_of (II->args (),
1036- [&II](Value *Arg) { return Arg->getType () == II->getType (); })) {
1037- Ins0 = II->getArgOperand (0 );
1038- Ins1 = II->getArgOperand (1 );
1039- } else {
1040- return false ;
1041- }
1042- }
1024+ bool VectorCombine::scalarizeOpOrCmp (Instruction &I) {
1025+ auto *UO = dyn_cast<UnaryOperator>(&I);
1026+ auto *BO = dyn_cast<BinaryOperator>(&I);
1027+ auto *CI = dyn_cast<CmpInst>(&I);
1028+ auto *II = dyn_cast<IntrinsicInst>(&I);
1029+ if (!UO && !BO && !CI && !II)
1030+ return false ;
1031+
1032+ // TODO: Allow intrinsics with different argument types
1033+ // TODO: Allow intrinsics with scalar arguments
1034+ if (II && (!isTriviallyVectorizable (II->getIntrinsicID ()) ||
1035+ !all_of (II->args (), [&II](Value *Arg) {
1036+ return Arg->getType () == II->getType ();
1037+ })))
1038+ return false ;
10431039
10441040 // Do not convert the vector condition of a vector select into a scalar
10451041 // condition. That may cause problems for codegen because of differences in
10461042 // boolean formats and register-file transfers.
10471043 // TODO: Can we account for that in the cost model?
1048- if (isa<CmpInst>(I) )
1044+ if (CI )
10491045 for (User *U : I.users ())
10501046 if (match (U, m_Select (m_Specific (&I), m_Value (), m_Value ())))
10511047 return false ;
10521048
1053- // Match against one or both scalar values being inserted into constant
1054- // vectors:
1055- // vec_op VecC0, (inselt VecC1, V1, Index)
1056- // vec_op (inselt VecC0, V0, Index), VecC1
1057- // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index)
1058- // TODO: Deal with mismatched index constants and variable indexes?
1059- Constant *VecC0 = nullptr , *VecC1 = nullptr ;
1060- Value *V0 = nullptr , *V1 = nullptr ;
1061- uint64_t Index0 = 0 , Index1 = 0 ;
1062- if (!match (Ins0, m_InsertElt (m_Constant (VecC0), m_Value (V0),
1063- m_ConstantInt (Index0))) &&
1064- !match (Ins0, m_Constant (VecC0)))
1065- return false ;
1066- if (!match (Ins1, m_InsertElt (m_Constant (VecC1), m_Value (V1),
1067- m_ConstantInt (Index1))) &&
1068- !match (Ins1, m_Constant (VecC1)))
1069- return false ;
1070-
1071- bool IsConst0 = !V0;
1072- bool IsConst1 = !V1;
1073- if (IsConst0 && IsConst1)
1074- return false ;
1075- if (!IsConst0 && !IsConst1 && Index0 != Index1)
1076- return false ;
1049+ // Match constant vectors or scalars being inserted into constant vectors:
1050+ // vec_op [VecC0 | (inselt VecC0, V0, Index)], ...
1051+ SmallVector<Constant *> VecCs;
1052+ SmallVector<Value *> ScalarOps;
1053+ std::optional<uint64_t > Index;
1054+
1055+ auto Ops = II ? II->args () : I.operand_values ();
1056+ for (Value *Op : Ops) {
1057+ Constant *VecC;
1058+ Value *V;
1059+ uint64_t InsIdx = 0 ;
1060+ VectorType *OpTy = cast<VectorType>(Op->getType ());
1061+ if (match (Op, m_InsertElt (m_Constant (VecC), m_Value (V),
1062+ m_ConstantInt (InsIdx)))) {
1063+ // Bail if any inserts are out of bounds.
1064+ if (OpTy->getElementCount ().getKnownMinValue () <= InsIdx)
1065+ return false ;
1066+ // All inserts must have the same index.
1067+ // TODO: Deal with mismatched index constants and variable indexes?
1068+ if (!Index)
1069+ Index = InsIdx;
1070+ else if (InsIdx != *Index)
1071+ return false ;
1072+ VecCs.push_back (VecC);
1073+ ScalarOps.push_back (V);
1074+ } else if (match (Op, m_Constant (VecC))) {
1075+ VecCs.push_back (VecC);
1076+ ScalarOps.push_back (nullptr );
1077+ } else {
1078+ return false ;
1079+ }
1080+ }
10771081
1078- auto *VecTy0 = cast<VectorType>(Ins0->getType ());
1079- auto *VecTy1 = cast<VectorType>(Ins1->getType ());
1080- if (VecTy0->getElementCount ().getKnownMinValue () <= Index0 ||
1081- VecTy1->getElementCount ().getKnownMinValue () <= Index1)
1082+ // Bail if all operands are constant.
1083+ if (!Index.has_value ())
10821084 return false ;
10831085
1084- uint64_t Index = IsConst0 ? Index1 : Index0;
1085- Type *ScalarTy = IsConst0 ? V1->getType () : V0->getType ();
1086- Type *VecTy = I.getType ();
1086+ VectorType *VecTy = cast<VectorType>(I.getType ());
1087+ Type *ScalarTy = VecTy->getScalarType ();
10871088 assert (VecTy->isVectorTy () &&
1088- (IsConst0 || IsConst1 || V0->getType () == V1->getType ()) &&
10891089 (ScalarTy->isIntegerTy () || ScalarTy->isFloatingPointTy () ||
10901090 ScalarTy->isPointerTy ()) &&
10911091 " Unexpected types for insert element into binop or cmp" );
10921092
10931093 unsigned Opcode = I.getOpcode ();
10941094 InstructionCost ScalarOpCost, VectorOpCost;
1095- if (isa<CmpInst>(I) ) {
1096- CmpInst::Predicate Pred = cast<CmpInst>(I). getPredicate ();
1095+ if (CI ) {
1096+ CmpInst::Predicate Pred = CI-> getPredicate ();
10971097 ScalarOpCost = TTI.getCmpSelInstrCost (
10981098 Opcode, ScalarTy, CmpInst::makeCmpResultType (ScalarTy), Pred, CostKind);
10991099 VectorOpCost = TTI.getCmpSelInstrCost (
11001100 Opcode, VecTy, CmpInst::makeCmpResultType (VecTy), Pred, CostKind);
1101- } else if (isa<BinaryOperator>(I) ) {
1101+ } else if (UO || BO ) {
11021102 ScalarOpCost = TTI.getArithmeticInstrCost (Opcode, ScalarTy, CostKind);
11031103 VectorOpCost = TTI.getArithmeticInstrCost (Opcode, VecTy, CostKind);
11041104 } else {
1105- auto *II = cast<IntrinsicInst>(&I);
11061105 IntrinsicCostAttributes ScalarICA (
11071106 II->getIntrinsicID (), ScalarTy,
11081107 SmallVector<Type *>(II->arg_size (), ScalarTy));
@@ -1115,56 +1114,59 @@ bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
11151114
11161115 // Fold the vector constants in the original vectors into a new base vector to
11171116 // get more accurate cost modelling.
1118- Value *NewVecC;
1119- if (isa<CmpInst>(I))
1120- NewVecC = ConstantFoldCompareInstOperands (Pred, VecC0, VecC1, *DL);
1121- else if (isa<BinaryOperator>(I))
1122- NewVecC = ConstantFoldBinaryOpOperands ((Instruction::BinaryOps)Opcode,
1123- VecC0, VecC1, *DL);
1124- else
1125- NewVecC = ConstantFoldBinaryIntrinsic (
1126- cast<IntrinsicInst>(I).getIntrinsicID (), VecC0, VecC1, I.getType (), &I);
1117+ Value *NewVecC = nullptr ;
1118+ if (CI)
1119+ NewVecC = ConstantFoldCompareInstOperands (CI->getPredicate (), VecCs[0 ],
1120+ VecCs[1 ], *DL);
1121+ else if (UO)
1122+ NewVecC = ConstantFoldUnaryOpOperand (Opcode, VecCs[0 ], *DL);
1123+ else if (BO)
1124+ NewVecC = ConstantFoldBinaryOpOperands (Opcode, VecCs[0 ], VecCs[1 ], *DL);
1125+ else if (II->arg_size () == 2 )
1126+ NewVecC = ConstantFoldBinaryIntrinsic (II->getIntrinsicID (), VecCs[0 ],
1127+ VecCs[1 ], II->getType (), II);
11271128
11281129 // Get cost estimate for the insert element. This cost will factor into
11291130 // both sequences.
1130- InstructionCost InsertCostNewVecC = TTI.getVectorInstrCost (
1131- Instruction::InsertElement, VecTy, CostKind, Index, NewVecC);
1132- InstructionCost InsertCostV0 = TTI.getVectorInstrCost (
1133- Instruction::InsertElement, VecTy, CostKind, Index, VecC0, V0);
1134- InstructionCost InsertCostV1 = TTI.getVectorInstrCost (
1135- Instruction::InsertElement, VecTy, CostKind, Index, VecC1, V1);
1136- InstructionCost OldCost = (IsConst0 ? 0 : InsertCostV0) +
1137- (IsConst1 ? 0 : InsertCostV1) + VectorOpCost;
1138- InstructionCost NewCost = ScalarOpCost + InsertCostNewVecC +
1139- (IsConst0 ? 0 : !Ins0->hasOneUse () * InsertCostV0) +
1140- (IsConst1 ? 0 : !Ins1->hasOneUse () * InsertCostV1);
1131+ InstructionCost OldCost = VectorOpCost;
1132+ InstructionCost NewCost =
1133+ ScalarOpCost + TTI.getVectorInstrCost (Instruction::InsertElement, VecTy,
1134+ CostKind, *Index, NewVecC);
1135+ for (auto [Op, VecC, Scalar] : zip (Ops, VecCs, ScalarOps)) {
1136+ if (!Scalar)
1137+ continue ;
1138+ InstructionCost InsertCost = TTI.getVectorInstrCost (
1139+ Instruction::InsertElement, VecTy, CostKind, *Index, VecC, Scalar);
1140+ OldCost += InsertCost;
1141+ NewCost += !Op->hasOneUse () * InsertCost;
1142+ }
1143+
11411144 // We want to scalarize unless the vector variant actually has lower cost.
11421145 if (OldCost < NewCost || !NewCost.isValid ())
11431146 return false ;
11441147
11451148 // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index) -->
11461149 // inselt NewVecC, (scalar_op V0, V1), Index
1147- if (isa<CmpInst>(I) )
1150+ if (CI )
11481151 ++NumScalarCmp;
1149- else if (isa<BinaryOperator>(I) )
1150- ++NumScalarBO ;
1151- else if (isa<IntrinsicInst>(I))
1152+ else if (UO || BO )
1153+ ++NumScalarOps ;
1154+ else
11521155 ++NumScalarIntrinsic;
11531156
11541157 // For constant cases, extract the scalar element, this should constant fold.
1155- if (IsConst0 )
1156- V0 = ConstantExpr::getExtractElement (VecC0, Builder. getInt64 (Index));
1157- if (IsConst1)
1158- V1 = ConstantExpr::getExtractElement (VecC1 , Builder.getInt64 (Index));
1158+ for ( auto [OpIdx, Scalar, VecC] : enumerate(ScalarOps, VecCs) )
1159+ if (!Scalar)
1160+ ScalarOps[OpIdx] = ConstantExpr::getExtractElement (
1161+ cast<Constant>(VecC) , Builder.getInt64 (* Index));
11591162
11601163 Value *Scalar;
1161- if (isa<CmpInst>(I) )
1162- Scalar = Builder.CreateCmp (Pred, V0, V1 );
1163- else if (isa<BinaryOperator>(I) )
1164- Scalar = Builder.CreateBinOp ((Instruction::BinaryOps) Opcode, V0, V1 );
1164+ if (CI )
1165+ Scalar = Builder.CreateCmp (CI-> getPredicate (), ScalarOps[ 0 ], ScalarOps[ 1 ] );
1166+ else if (UO || BO )
1167+ Scalar = Builder.CreateNAryOp ( Opcode, ScalarOps );
11651168 else
1166- Scalar = Builder.CreateIntrinsic (
1167- ScalarTy, cast<IntrinsicInst>(I).getIntrinsicID (), {V0, V1});
1169+ Scalar = Builder.CreateIntrinsic (ScalarTy, II->getIntrinsicID (), ScalarOps);
11681170
11691171 Scalar->setName (I.getName () + " .scalar" );
11701172
@@ -1175,16 +1177,18 @@ bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
11751177
11761178 // Create a new base vector if the constant folding failed.
11771179 if (!NewVecC) {
1178- if (isa<CmpInst>(I))
1179- NewVecC = Builder.CreateCmp (Pred, VecC0, VecC1);
1180- else if (isa<BinaryOperator>(I))
1181- NewVecC =
1182- Builder.CreateBinOp ((Instruction::BinaryOps)Opcode, VecC0, VecC1);
1180+ SmallVector<Value *> VecCValues;
1181+ VecCValues.reserve (VecCs.size ());
1182+ append_range (VecCValues, VecCs);
1183+ if (CI)
1184+ NewVecC = Builder.CreateCmp (CI->getPredicate (), VecCs[0 ], VecCs[1 ]);
1185+ else if (UO || BO)
1186+ NewVecC = Builder.CreateNAryOp (Opcode, VecCValues);
11831187 else
1184- NewVecC = Builder. CreateIntrinsic (
1185- VecTy, cast<IntrinsicInst>(I). getIntrinsicID (), {VecC0, VecC1} );
1188+ NewVecC =
1189+ Builder. CreateIntrinsic ( VecTy, II-> getIntrinsicID (), VecCValues );
11861190 }
1187- Value *Insert = Builder.CreateInsertElement (NewVecC, Scalar, Index);
1191+ Value *Insert = Builder.CreateInsertElement (NewVecC, Scalar, * Index);
11881192 replaceValue (I, *Insert);
11891193 return true ;
11901194}
@@ -3570,7 +3574,7 @@ bool VectorCombine::run() {
35703574 // This transform works with scalable and fixed vectors
35713575 // TODO: Identify and allow other scalable transforms
35723576 if (IsVectorType) {
3573- MadeChange |= scalarizeBinopOrCmp (I);
3577+ MadeChange |= scalarizeOpOrCmp (I);
35743578 MadeChange |= scalarizeLoadExtract (I);
35753579 MadeChange |= scalarizeVPIntrinsic (I);
35763580 MadeChange |= foldInterleaveIntrinsics (I);
0 commit comments