Skip to content

Commit b5ae180

Browse files
committed
merge InterchangeableInstruction and InterchangeableBinOp
rename isSame to add rename tryAnd to trySet make Mask support MainOp_BIT
1 parent ba9ab59 commit b5ae180

File tree

1 file changed

+62
-107
lines changed

1 file changed

+62
-107
lines changed

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 62 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -814,34 +814,14 @@ namespace {
814814
/// equivalent forms. For example, multiplication by a power of 2 can be
815815
/// interchanged with a left shift.
816816
///
817-
/// Derived classes implement specific interchange patterns by overriding the
818-
/// virtual methods to define their interchange logic.
819-
///
820817
/// The class maintains a reference to the main instruction (MainOp) and
821818
/// provides methods to:
822-
/// - Check if another instruction is interchangeable (isSame)
819+
/// - Check if the incoming instruction can use the same instruction as MainOp
820+
/// (add)
823821
/// - Get the opcode for the interchangeable form (getOpcode)
824822
/// - Get the operands for the interchangeable form (getOperand)
825-
class InterchangeableInstruction {
826-
protected:
827-
Instruction *const MainOp = nullptr;
828-
829-
public:
830-
InterchangeableInstruction(Instruction *MainOp) : MainOp(MainOp) {}
831-
virtual bool isSame(Instruction *I) const {
832-
return MainOp->getOpcode() == I->getOpcode();
833-
}
834-
virtual unsigned getOpcode() const { return MainOp->getOpcode(); }
835-
virtual SmallVector<Value *> getOperand(Instruction *I) const {
836-
assert(MainOp->getOpcode() == I->getOpcode() &&
837-
"Cannot convert the instruction.");
838-
return SmallVector<Value *>(MainOp->operands());
839-
}
840-
virtual ~InterchangeableInstruction() = default;
841-
};
842-
843-
class InterchangeableBinOp final : public InterchangeableInstruction {
844-
using MaskType = std::uint_fast8_t;
823+
class InterchangeableBinOp {
824+
using MaskType = std::uint_fast16_t;
845825
// Sort SupportedOp because it is used by binary_search.
846826
constexpr static std::initializer_list<unsigned> SupportedOp = {
847827
Instruction::Add, Instruction::Sub, Instruction::Mul, Instruction::Shl,
@@ -855,15 +835,17 @@ class InterchangeableBinOp final : public InterchangeableInstruction {
855835
And_BIT = 0b100000,
856836
Or_BIT = 0b1000000,
857837
Xor_BIT = 0b10000000,
838+
MainOp_BIT = 0b100000000
858839
};
840+
Instruction *MainOp = nullptr;
859841
// The bit it sets represents whether MainOp can be converted to.
860-
mutable MaskType Mask = Xor_BIT | Or_BIT | And_BIT | Sub_BIT | Add_BIT |
861-
Mul_BIT | AShr_BIT | SHL_BIT;
842+
MaskType Mask = MainOp_BIT | Xor_BIT | Or_BIT | And_BIT | Sub_BIT | Add_BIT |
843+
Mul_BIT | AShr_BIT | SHL_BIT;
862844
// We cannot create an interchangeable instruction that does not exist in VL.
863845
// For example, VL [x + 0, y * 1] can be converted to [x << 0, y << 0], but
864846
// 'shl' does not exist in VL. In the end, we convert VL to [x * 1, y * 1].
865847
// SeenBefore is used to know what operations have been seen before.
866-
mutable MaskType SeenBefore = 0;
848+
MaskType SeenBefore = 0;
867849

868850
/// Return a non-nullptr if either operand of I is a ConstantInt.
869851
static std::pair<ConstantInt *, unsigned>
@@ -910,7 +892,7 @@ class InterchangeableBinOp final : public InterchangeableInstruction {
910892
llvm_unreachable("Unsupported opcode.");
911893
}
912894

913-
bool tryAnd(MaskType X) const {
895+
bool trySet(MaskType X) {
914896
if (Mask & X) {
915897
Mask &= X;
916898
return true;
@@ -919,39 +901,47 @@ class InterchangeableBinOp final : public InterchangeableInstruction {
919901
}
920902

921903
public:
922-
using InterchangeableInstruction::InterchangeableInstruction;
923-
bool isSame(Instruction *I) const override {
904+
InterchangeableBinOp(Instruction *MainOp) : MainOp(MainOp) {}
905+
bool add(Instruction *I) {
924906
unsigned Opcode = I->getOpcode();
925-
if (!binary_search(SupportedOp, Opcode))
907+
assert(is_sorted(SupportedOp) && "SupportedOp is not sorted.");
908+
if (!binary_search(SupportedOp, Opcode)) {
909+
if (MainOp->getOpcode() == Opcode)
910+
return trySet(MainOp_BIT);
926911
return false;
912+
}
927913
SeenBefore |= opcodeToMask(Opcode);
928914
ConstantInt *CI = isBinOpWithConstantInt(I).first;
929915
if (CI) {
916+
constexpr MaskType CanBeAll = Xor_BIT | Or_BIT | And_BIT | Sub_BIT |
917+
Add_BIT | Mul_BIT | AShr_BIT | SHL_BIT;
930918
const APInt &CIValue = CI->getValue();
931919
switch (Opcode) {
932920
case Instruction::Shl:
933921
if (CIValue.isZero())
934-
return true;
935-
return tryAnd(Mul_BIT | SHL_BIT);
922+
return trySet(CanBeAll);
923+
return trySet(Mul_BIT | SHL_BIT);
936924
case Instruction::Mul:
937925
if (CIValue.isOne())
938-
return true;
926+
return trySet(CanBeAll);
939927
if (CIValue.isPowerOf2())
940-
return tryAnd(Mul_BIT | SHL_BIT);
928+
return trySet(Mul_BIT | SHL_BIT);
941929
break;
942930
case Instruction::And:
943931
if (CIValue.isAllOnes())
944-
return true;
932+
return trySet(CanBeAll);
945933
break;
946934
default:
947935
if (CIValue.isZero())
948-
return true;
936+
return trySet(CanBeAll);
949937
break;
950938
}
951939
}
952-
return tryAnd(opcodeToMask(Opcode));
940+
return trySet(opcodeToMask(Opcode));
953941
}
954-
unsigned getOpcode() const override {
942+
unsigned getOpcode() const {
943+
if (Mask & MainOp_BIT)
944+
return MainOp->getOpcode();
955945
MaskType Candidate = Mask & SeenBefore;
956946
if (Candidate & SHL_BIT)
957947
return Instruction::Shl;
@@ -971,12 +961,12 @@ class InterchangeableBinOp final : public InterchangeableInstruction {
971961
return Instruction::Xor;
972962
llvm_unreachable("Cannot find interchangeable instruction.");
973963
}
974-
SmallVector<Value *> getOperand(Instruction *I) const override {
964+
SmallVector<Value *> getOperand(Instruction *I) const {
975965
unsigned ToOpcode = I->getOpcode();
976-
assert(binary_search(SupportedOp, ToOpcode) && "Unsupported opcode.");
977966
unsigned FromOpcode = MainOp->getOpcode();
978967
if (FromOpcode == ToOpcode)
979968
return SmallVector<Value *>(MainOp->operands());
969+
assert(binary_search(SupportedOp, ToOpcode) && "Unsupported opcode.");
980970
auto [CI, Pos] = isBinOpWithConstantInt(MainOp);
981971
const APInt &FromCIValue = CI->getValue();
982972
unsigned FromCIValueBitWidth = FromCIValue.getBitWidth();
@@ -1023,27 +1013,12 @@ class InterchangeableBinOp final : public InterchangeableInstruction {
10231013
}
10241014
};
10251015

1026-
static SmallVector<std::unique_ptr<InterchangeableInstruction>>
1027-
getInterchangeableInstruction(Instruction *MainOp) {
1028-
SmallVector<std::unique_ptr<InterchangeableInstruction>> Candidate;
1029-
Candidate.push_back(std::make_unique<InterchangeableInstruction>(MainOp));
1030-
if (MainOp->isBinaryOp())
1031-
Candidate.push_back(std::make_unique<InterchangeableBinOp>(MainOp));
1032-
return Candidate;
1033-
}
1034-
1035-
static bool getInterchangeableInstruction(
1036-
SmallVector<std::unique_ptr<InterchangeableInstruction>> &Candidate,
1037-
Instruction *I) {
1038-
auto Iter = std::stable_partition(
1039-
Candidate.begin(), Candidate.end(),
1040-
[&](const std::unique_ptr<InterchangeableInstruction> &C) {
1041-
return C->isSame(I);
1042-
});
1043-
if (Iter == Candidate.begin())
1044-
return false;
1045-
Candidate.erase(Iter, Candidate.end());
1046-
return true;
1016+
static std::optional<InterchangeableBinOp> isConvertible(Instruction *From,
1017+
Instruction *To) {
1018+
InterchangeableBinOp Converter(From);
1019+
if (Converter.add(From) && Converter.add(To))
1020+
return Converter;
1021+
return {};
10471022
}
10481023

10491024
static bool isConvertible(Instruction *I, Instruction *MainOp,
@@ -1056,16 +1031,7 @@ static bool isConvertible(Instruction *I, Instruction *MainOp,
10561031
return true;
10571032
if (!I->isBinaryOp())
10581033
return false;
1059-
SmallVector<std::unique_ptr<InterchangeableInstruction>> Candidate(
1060-
getInterchangeableInstruction(I));
1061-
for (std::unique_ptr<InterchangeableInstruction> &C : Candidate)
1062-
if (C->isSame(I) && C->isSame(MainOp))
1063-
return true;
1064-
Candidate = getInterchangeableInstruction(I);
1065-
for (std::unique_ptr<InterchangeableInstruction> &C : Candidate)
1066-
if (C->isSame(I) && C->isSame(AltOp))
1067-
return true;
1068-
return false;
1034+
return isConvertible(I, MainOp) || isConvertible(I, AltOp);
10691035
}
10701036

10711037
static std::pair<Instruction *, SmallVector<Value *>>
@@ -1077,15 +1043,12 @@ convertTo(Instruction *I, Instruction *MainOp, Instruction *AltOp) {
10771043
if (I->getOpcode() == AltOp->getOpcode())
10781044
return std::make_pair(AltOp, SmallVector<Value *>(I->operands()));
10791045
assert(I->isBinaryOp() && "Cannot convert the instruction.");
1080-
SmallVector<std::unique_ptr<InterchangeableInstruction>> Candidate(
1081-
getInterchangeableInstruction(I));
1082-
for (std::unique_ptr<InterchangeableInstruction> &C : Candidate)
1083-
if (C->isSame(I) && C->isSame(MainOp))
1084-
return std::make_pair(MainOp, C->getOperand(MainOp));
1085-
Candidate = getInterchangeableInstruction(I);
1086-
for (std::unique_ptr<InterchangeableInstruction> &C : Candidate)
1087-
if (C->isSame(I) && C->isSame(AltOp))
1088-
return std::make_pair(AltOp, C->getOperand(AltOp));
1046+
std::optional<InterchangeableBinOp> Converter(isConvertible(I, MainOp));
1047+
if (Converter)
1048+
return std::make_pair(MainOp, Converter->getOperand(MainOp));
1049+
Converter = isConvertible(I, AltOp);
1050+
if (Converter)
1051+
return std::make_pair(AltOp, Converter->getOperand(AltOp));
10891052
llvm_unreachable("Cannot convert the instruction.");
10901053
}
10911054

@@ -1209,11 +1172,8 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
12091172
unsigned Opcode = MainOp->getOpcode();
12101173
unsigned AltOpcode = Opcode;
12111174

1212-
SmallVector<std::unique_ptr<InterchangeableInstruction>>
1213-
InterchangeableInstructionCandidate(
1214-
getInterchangeableInstruction(MainOp));
1215-
SmallVector<std::unique_ptr<InterchangeableInstruction>>
1216-
AlternateInterchangeableInstructionCandidate;
1175+
InterchangeableBinOp InterchangeableConverter(MainOp);
1176+
std::optional<InterchangeableBinOp> AlternateInterchangeableConverter;
12171177
bool SwappedPredsCompatible = IsCmpOp && [&]() {
12181178
SetVector<unsigned> UniquePreds, UniqueNonSwappedPreds;
12191179
UniquePreds.insert(BasePred);
@@ -1260,17 +1220,15 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
12601220
return InstructionsState::invalid();
12611221
unsigned InstOpcode = I->getOpcode();
12621222
if (IsBinOp && isa<BinaryOperator>(I)) {
1263-
if (getInterchangeableInstruction(InterchangeableInstructionCandidate, I))
1223+
if (InterchangeableConverter.add(I))
12641224
continue;
1265-
if (AlternateInterchangeableInstructionCandidate.empty()) {
1225+
if (!AlternateInterchangeableConverter) {
12661226
if (!isValidForAlternation(Opcode) ||
12671227
!isValidForAlternation(InstOpcode))
12681228
return InstructionsState::invalid();
1269-
AlternateInterchangeableInstructionCandidate =
1270-
getInterchangeableInstruction(I);
1229+
AlternateInterchangeableConverter = InterchangeableBinOp(I);
12711230
}
1272-
if (getInterchangeableInstruction(
1273-
AlternateInterchangeableInstructionCandidate, I))
1231+
if (AlternateInterchangeableConverter->add(I))
12741232
continue;
12751233
} else if (IsCastOp && isa<CastInst>(I)) {
12761234
Value *Op0 = MainOp->getOperand(0);
@@ -1374,25 +1332,22 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
13741332

13751333
if (IsBinOp) {
13761334
auto FindOp =
1377-
[&](ArrayRef<std::unique_ptr<InterchangeableInstruction>> Candidate) {
1378-
for (const std::unique_ptr<InterchangeableInstruction> &I :
1379-
Candidate) {
1380-
unsigned InterchangeableInstructionOpcode = I->getOpcode();
1381-
for (Value *V : VL) {
1382-
if (isa<PoisonValue>(V))
1383-
continue;
1384-
auto *Inst = cast<Instruction>(V);
1385-
if (Inst->getOpcode() == InterchangeableInstructionOpcode)
1386-
return Inst;
1387-
}
1335+
[&](const InterchangeableBinOp &Converter) {
1336+
unsigned InterchangeableInstructionOpcode = Converter.getOpcode();
1337+
for (Value *V : VL) {
1338+
if (isa<PoisonValue>(V))
1339+
continue;
1340+
auto *Inst = cast<Instruction>(V);
1341+
if (Inst->getOpcode() == InterchangeableInstructionOpcode)
1342+
return Inst;
13881343
}
13891344
llvm_unreachable(
13901345
"Cannot find the candidate instruction for InstructionsState.");
13911346
};
1392-
MainOp = FindOp(InterchangeableInstructionCandidate);
1393-
AltOp = AlternateInterchangeableInstructionCandidate.empty()
1394-
? MainOp
1395-
: FindOp(AlternateInterchangeableInstructionCandidate);
1347+
MainOp = FindOp(InterchangeableConverter);
1348+
AltOp = AlternateInterchangeableConverter
1349+
? FindOp(*AlternateInterchangeableConverter)
1350+
: MainOp;
13961351
}
13971352
return InstructionsState(MainOp, AltOp);
13981353
}

0 commit comments

Comments
 (0)