@@ -842,8 +842,123 @@ class InstructionsState {
842842 static InstructionsState invalid() { return {nullptr, nullptr}; }
843843};
844844
845+ struct InterchangeableInstruction {
846+ unsigned Opcode;
847+ SmallVector<Value *> Ops;
848+ template <class... ArgTypes>
849+ InterchangeableInstruction(unsigned Opcode, ArgTypes &&...Args)
850+ : Opcode(Opcode), Ops{std::forward<decltype(Args)>(Args)...} {}
851+ };
852+
853+ bool operator<(const InterchangeableInstruction &LHS,
854+ const InterchangeableInstruction &RHS) {
855+ return LHS.Opcode < RHS.Opcode;
856+ }
857+
845858} // end anonymous namespace
846859
860+ /// \returns a sorted list of interchangeable instructions by instruction opcode
861+ /// that \p I can be converted to.
862+ /// e.g.,
863+ /// x << y -> x * (2^y)
864+ /// x << 1 -> x * 2
865+ /// x << 0 -> x * 1 -> x - 0 -> x + 0 -> x & 11...1 -> x | 0
866+ /// x * 0 -> x & 0
867+ /// x * -1 -> 0 - x
868+ /// TODO: support more patterns
869+ static SmallVector<InterchangeableInstruction>
870+ getInterchangeableInstruction(Instruction *I) {
871+ // PII = Possible Interchangeable Instruction
872+ SmallVector<InterchangeableInstruction> PII;
873+ unsigned Opcode = I->getOpcode();
874+ PII.emplace_back(Opcode, I->operands());
875+ if (!is_contained({Instruction::Shl, Instruction::Mul, Instruction::Sub,
876+ Instruction::Add},
877+ Opcode))
878+ return PII;
879+ Constant *C;
880+ if (match(I, m_BinOp(m_Value(), m_Constant(C)))) {
881+ ConstantInt *V = nullptr;
882+ if (auto *CI = dyn_cast<ConstantInt>(C)) {
883+ V = CI;
884+ } else if (auto *CDV = dyn_cast<ConstantDataVector>(C)) {
885+ if (auto *CI = dyn_cast_if_present<ConstantInt>(CDV->getSplatValue()))
886+ V = CI;
887+ }
888+ if (!V)
889+ return PII;
890+ Value *Op0 = I->getOperand(0);
891+ Type *Op1Ty = I->getOperand(1)->getType();
892+ const APInt &Op1Int = V->getValue();
893+ Constant *Zero =
894+ ConstantInt::get(Op1Ty, APInt::getZero(Op1Int.getBitWidth()));
895+ Constant *UnsignedMax =
896+ ConstantInt::get(Op1Ty, APInt::getMaxValue(Op1Int.getBitWidth()));
897+ switch (Opcode) {
898+ case Instruction::Shl: {
899+ PII.emplace_back(Instruction::Mul, Op0,
900+ ConstantInt::get(Op1Ty, 1 << Op1Int.getZExtValue()));
901+ if (Op1Int.isZero()) {
902+ PII.emplace_back(Instruction::Sub, Op0, Zero);
903+ PII.emplace_back(Instruction::Add, Op0, Zero);
904+ PII.emplace_back(Instruction::And, Op0, UnsignedMax);
905+ PII.emplace_back(Instruction::Or, Op0, Zero);
906+ }
907+ break;
908+ }
909+ case Instruction::Mul: {
910+ if (Op1Int.isOne()) {
911+ PII.emplace_back(Instruction::Sub, Op0, Zero);
912+ PII.emplace_back(Instruction::Add, Op0, Zero);
913+ PII.emplace_back(Instruction::And, Op0, UnsignedMax);
914+ PII.emplace_back(Instruction::Or, Op0, Zero);
915+ } else if (Op1Int.isZero()) {
916+ PII.emplace_back(Instruction::And, Op0, Zero);
917+ } else if (Op1Int.isAllOnes()) {
918+ PII.emplace_back(Instruction::Sub, Zero, Op0);
919+ }
920+ break;
921+ }
922+ case Instruction::Sub:
923+ if (Op1Int.isZero()) {
924+ PII.emplace_back(Instruction::Add, Op0, Zero);
925+ PII.emplace_back(Instruction::And, Op0, UnsignedMax);
926+ PII.emplace_back(Instruction::Or, Op0, Zero);
927+ }
928+ break;
929+ case Instruction::Add:
930+ if (Op1Int.isZero()) {
931+ PII.emplace_back(Instruction::And, Op0, UnsignedMax);
932+ PII.emplace_back(Instruction::Or, Op0, Zero);
933+ }
934+ break;
935+ }
936+ }
937+ // std::set_intersection requires a sorted range.
938+ sort(PII);
939+ return PII;
940+ }
941+
942+ /// \returns the Op and operands which \p I convert to.
943+ static std::pair<Value *, SmallVector<Value *>>
944+ getInterchangeableInstruction(Instruction *I, Instruction *MainOp,
945+ Instruction *AltOp) {
946+ SmallVector<InterchangeableInstruction> IIList =
947+ getInterchangeableInstruction(I);
948+ const auto *Iter = find_if(IIList, [&](const InterchangeableInstruction &II) {
949+ return II.Opcode == MainOp->getOpcode();
950+ });
951+ if (Iter == IIList.end()) {
952+ Iter = find_if(IIList, [&](const InterchangeableInstruction &II) {
953+ return II.Opcode == AltOp->getOpcode();
954+ });
955+ assert(Iter != IIList.end() &&
956+ "Cannot find an interchangeable instruction.");
957+ return std::make_pair(AltOp, Iter->Ops);
958+ }
959+ return std::make_pair(MainOp, Iter->Ops);
960+ }
961+
847962/// \returns true if \p Opcode is allowed as part of the main/alternate
848963/// instruction for SLP vectorization.
849964///
@@ -957,6 +1072,22 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
9571072 return InstructionsState::invalid();
9581073 }
9591074 bool AnyPoison = InstCnt != VL.size();
1075+ // Currently, this is only used for binary ops.
1076+ // TODO: support all instructions
1077+ SmallVector<InterchangeableInstruction> InterchangeableOpcode =
1078+ getInterchangeableInstruction(cast<Instruction>(V));
1079+ SmallVector<InterchangeableInstruction> AlternateInterchangeableOpcode;
1080+ auto UpdateInterchangeableOpcode =
1081+ [](SmallVector<InterchangeableInstruction> &LHS,
1082+ ArrayRef<InterchangeableInstruction> RHS) {
1083+ SmallVector<InterchangeableInstruction> NewInterchangeableOpcode;
1084+ std::set_intersection(LHS.begin(), LHS.end(), RHS.begin(), RHS.end(),
1085+ std::back_inserter(NewInterchangeableOpcode));
1086+ if (NewInterchangeableOpcode.empty())
1087+ return false;
1088+ LHS.swap(NewInterchangeableOpcode);
1089+ return true;
1090+ };
9601091 for (int Cnt = 0, E = VL.size(); Cnt < E; Cnt++) {
9611092 auto *I = dyn_cast<Instruction>(VL[Cnt]);
9621093 if (!I)
@@ -969,14 +1100,32 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
9691100 return InstructionsState::invalid();
9701101 unsigned InstOpcode = I->getOpcode();
9711102 if (IsBinOp && isa<BinaryOperator>(I)) {
972- if (InstOpcode == Opcode || InstOpcode == AltOpcode)
1103+ SmallVector<InterchangeableInstruction> ThisInterchangeableOpcode(
1104+ getInterchangeableInstruction(I));
1105+ if (UpdateInterchangeableOpcode(InterchangeableOpcode,
1106+ ThisInterchangeableOpcode))
9731107 continue;
974- if (Opcode == AltOpcode && isValidForAlternation(InstOpcode) &&
975- isValidForAlternation(Opcode)) {
976- AltOpcode = InstOpcode;
977- AltIndex = Cnt;
1108+ if (AlternateInterchangeableOpcode.empty()) {
1109+ InterchangeableOpcode.erase(
1110+ remove_if(InterchangeableOpcode,
1111+ [](const InterchangeableInstruction &I) {
1112+ return !isValidForAlternation(I.Opcode);
1113+ }),
1114+ InterchangeableOpcode.end());
1115+ ThisInterchangeableOpcode.erase(
1116+ remove_if(ThisInterchangeableOpcode,
1117+ [](const InterchangeableInstruction &I) {
1118+ return !isValidForAlternation(I.Opcode);
1119+ }),
1120+ ThisInterchangeableOpcode.end());
1121+ if (InterchangeableOpcode.empty() || ThisInterchangeableOpcode.empty())
1122+ return InstructionsState::invalid();
1123+ AlternateInterchangeableOpcode.swap(ThisInterchangeableOpcode);
9781124 continue;
9791125 }
1126+ if (UpdateInterchangeableOpcode(AlternateInterchangeableOpcode,
1127+ ThisInterchangeableOpcode))
1128+ continue;
9801129 } else if (IsCastOp && isa<CastInst>(I)) {
9811130 Value *Op0 = IBase->getOperand(0);
9821131 Type *Ty0 = Op0->getType();
@@ -1077,6 +1226,24 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
10771226 return InstructionsState::invalid();
10781227 }
10791228
1229+ if (IsBinOp) {
1230+ auto FindOp = [&](ArrayRef<InterchangeableInstruction> CandidateOp) {
1231+ for (Value *V : VL) {
1232+ if (isa<PoisonValue>(V))
1233+ continue;
1234+ for (const InterchangeableInstruction &I : CandidateOp)
1235+ if (cast<Instruction>(V)->getOpcode() == I.Opcode)
1236+ return cast<Instruction>(V);
1237+ }
1238+ llvm_unreachable(
1239+ "Cannot find the candidate instruction for InstructionsState.");
1240+ };
1241+ Instruction *MainOp = FindOp(InterchangeableOpcode);
1242+ Instruction *AltOp = AlternateInterchangeableOpcode.empty()
1243+ ? MainOp
1244+ : FindOp(AlternateInterchangeableOpcode);
1245+ return InstructionsState(MainOp, AltOp);
1246+ }
10801247 return InstructionsState(cast<Instruction>(V),
10811248 cast<Instruction>(VL[AltIndex]));
10821249}
@@ -2407,42 +2574,46 @@ class BoUpSLP {
24072574 }
24082575
24092576 /// Go through the instructions in VL and append their operands.
2410- void appendOperandsOfVL(ArrayRef<Value *> VL, Instruction *VL0) {
2577+ void appendOperandsOfVL(ArrayRef<Value *> VL, Instruction *MainOp,
2578+ Instruction *AltOp) {
24112579 assert(!VL.empty() && "Bad VL");
24122580 assert((empty() || VL.size() == getNumLanes()) &&
24132581 "Expected same number of lanes");
24142582 // IntrinsicInst::isCommutative returns true if swapping the first "two"
24152583 // arguments to the intrinsic produces the same result.
24162584 constexpr unsigned IntrinsicNumOperands = 2;
2417- unsigned NumOperands = VL0 ->getNumOperands();
2418- ArgSize = isa<IntrinsicInst>(VL0 ) ? IntrinsicNumOperands : NumOperands;
2585+ unsigned NumOperands = MainOp ->getNumOperands();
2586+ ArgSize = isa<IntrinsicInst>(MainOp ) ? IntrinsicNumOperands : NumOperands;
24192587 OpsVec.resize(NumOperands);
24202588 unsigned NumLanes = VL.size();
2421- for (unsigned OpIdx = 0; OpIdx != NumOperands; ++OpIdx) {
2589+ for (unsigned OpIdx : seq<unsigned>( NumOperands))
24222590 OpsVec[OpIdx].resize(NumLanes);
2423- for (unsigned Lane = 0; Lane != NumLanes; ++Lane) {
2424- assert((isa<Instruction>(VL[Lane]) || isa<PoisonValue>(VL[Lane])) &&
2425- "Expected instruction or poison value");
2426- // Our tree has just 3 nodes: the root and two operands.
2427- // It is therefore trivial to get the APO. We only need to check the
2428- // opcode of VL[Lane] and whether the operand at OpIdx is the LHS or
2429- // RHS operand. The LHS operand of both add and sub is never attached
2430- // to an inversese operation in the linearized form, therefore its APO
2431- // is false. The RHS is true only if VL[Lane] is an inverse operation.
2432-
2433- // Since operand reordering is performed on groups of commutative
2434- // operations or alternating sequences (e.g., +, -), we can safely
2435- // tell the inverse operations by checking commutativity.
2436- if (isa<PoisonValue>(VL[Lane])) {
2591+ for (auto [Lane, V] : enumerate(VL)) {
2592+ assert((isa<Instruction>(V) || isa<PoisonValue>(V)) &&
2593+ "Expected instruction or poison value");
2594+ if (isa<PoisonValue>(V)) {
2595+ for (unsigned OpIdx : seq<unsigned>(NumOperands))
24372596 OpsVec[OpIdx][Lane] = {
2438- PoisonValue::get(VL0 ->getOperand(OpIdx)->getType()), true,
2597+ PoisonValue::get(MainOp ->getOperand(OpIdx)->getType()), true,
24392598 false};
2440- continue;
2441- }
2442- bool IsInverseOperation = !isCommutative(cast<Instruction>(VL[Lane]));
2599+ continue;
2600+ }
2601+ auto [SelectedOp, Ops] =
2602+ getInterchangeableInstruction(cast<Instruction>(V), MainOp, AltOp);
2603+ // Our tree has just 3 nodes: the root and two operands.
2604+ // It is therefore trivial to get the APO. We only need to check the
2605+ // opcode of V and whether the operand at OpIdx is the LHS or RHS
2606+ // operand. The LHS operand of both add and sub is never attached to an
2607+ // inverse operation in the linearized form, therefore its APO is
2608+ // false. The RHS is true only if V is an inverse operation.
2609+
2610+ // Since operand reordering is performed on groups of commutative
2611+ // operations or alternating sequences (e.g., +, -), we can safely
2612+ // tell the inverse operations by checking commutativity.
2613+ bool IsInverseOperation = !isCommutative(cast<Instruction>(SelectedOp));
2614+ for (unsigned OpIdx : seq<unsigned>(NumOperands)) {
24432615 bool APO = (OpIdx == 0) ? false : IsInverseOperation;
2444- OpsVec[OpIdx][Lane] = {cast<Instruction>(VL[Lane])->getOperand(OpIdx),
2445- APO, false};
2616+ OpsVec[OpIdx][Lane] = {Ops[OpIdx], APO, false};
24462617 }
24472618 }
24482619 }
@@ -2549,11 +2720,12 @@ class BoUpSLP {
25492720
25502721 public:
25512722 /// Initialize with all the operands of the instruction vector \p RootVL.
2552- VLOperands(ArrayRef<Value *> RootVL, Instruction *VL0, const BoUpSLP &R)
2723+ VLOperands(ArrayRef<Value *> RootVL, Instruction *MainOp,
2724+ Instruction *AltOp, const BoUpSLP &R)
25532725 : TLI(*R.TLI), DL(*R.DL), SE(*R.SE), R(R),
2554- L(R.LI->getLoopFor((VL0 ->getParent() ))) {
2726+ L(R.LI->getLoopFor(MainOp ->getParent())) {
25552727 // Append all the operands of RootVL.
2556- appendOperandsOfVL(RootVL, VL0 );
2728+ appendOperandsOfVL(RootVL, MainOp, AltOp );
25572729 }
25582730
25592731 /// \Returns a value vector with the operands across all lanes for the
@@ -3345,7 +3517,7 @@ class BoUpSLP {
33453517
33463518 /// Set this bundle's operand from Scalars.
33473519 void setOperand(const BoUpSLP &R, bool RequireReorder = false) {
3348- VLOperands Ops(Scalars, MainOp, R);
3520+ VLOperands Ops(Scalars, MainOp, AltOp, R);
33493521 if (RequireReorder)
33503522 Ops.reorder();
33513523 for (unsigned I : seq<unsigned>(MainOp->getNumOperands()))
@@ -8561,7 +8733,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
85618733 LLVM_DEBUG(dbgs() << "SLP: added a vector of compares.\n");
85628734
85638735 ValueList Left, Right;
8564- VLOperands Ops(VL, VL0, *this);
8736+ VLOperands Ops(VL, VL0, S.getAltOp(), *this);
85658737 if (cast<CmpInst>(VL0)->isCommutative()) {
85668738 // Commutative predicate - collect + sort operands of the instructions
85678739 // so that each side is more likely to have the same opcode.
@@ -15619,7 +15791,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
1561915791 Value *V = Builder.CreateBinOp(
1562015792 static_cast<Instruction::BinaryOps>(E->getOpcode()), LHS,
1562115793 RHS);
15622- propagateIRFlags(V, E->Scalars, VL0 , It == MinBWs.end());
15794+ propagateIRFlags(V, E->Scalars, nullptr , It == MinBWs.end());
1562315795 if (auto *I = dyn_cast<Instruction>(V)) {
1562415796 V = ::propagateMetadata(I, E->Scalars);
1562515797 // Drop nuw flags for abs(sub(commutative), true).
0 commit comments