@@ -4365,14 +4365,18 @@ class BoUpSLP {
4365
4365
} else {
4366
4366
// Build a map for gathered scalars to the nodes where they are used.
4367
4367
bool AllConstsOrCasts = true;
4368
- for (Value *V : VL)
4368
+ for (Value *V : VL) {
4369
+ if (S && S.areInstructionsWithCopyableElements() &&
4370
+ S.isCopyableElement(V))
4371
+ Last->addCopyableElement(V);
4369
4372
if (!isConstant(V)) {
4370
4373
auto *I = dyn_cast<CastInst>(V);
4371
4374
AllConstsOrCasts &= I && I->getType()->isIntegerTy();
4372
4375
if (UserTreeIdx.EdgeIdx != UINT_MAX || !UserTreeIdx.UserTE ||
4373
4376
!UserTreeIdx.UserTE->isGather())
4374
4377
ValueToGatherNodes.try_emplace(V).first->getSecond().insert(Last);
4375
4378
}
4379
+ }
4376
4380
if (AllConstsOrCasts)
4377
4381
CastMaxMinBWSizes =
4378
4382
std::make_pair(std::numeric_limits<unsigned>::max(), 1);
@@ -10564,35 +10568,41 @@ class InstructionsCompatibilityAnalysis {
10564
10568
unsigned MainOpcode = 0;
10565
10569
Instruction *MainOp = nullptr;
10566
10570
10571
+ /// Checks if the opcode is supported as the main opcode for copyable
10572
+ /// elements.
10573
+ static bool isSupportedOpcode(const unsigned Opcode) {
10574
+ return Opcode == Instruction::Add || Opcode == Instruction::LShr;
10575
+ }
10576
+
10567
10577
/// Identifies the best candidate value, which represents main opcode
10568
10578
/// operation.
10569
10579
/// Currently the best candidate is an Add or LShr instruction with the parent
10570
10580
/// block with the highest DFS incoming number (block, that dominates other).
10571
10581
void findAndSetMainInstruction(ArrayRef<Value *> VL, const BoUpSLP &R) {
10572
10582
BasicBlock *Parent = nullptr;
10573
10583
// Checks if the instruction has supported opcode.
10574
- auto IsSupportedOpcode = [&](Instruction *I) {
10575
- return I && I->getOpcode() == Instruction::Add &&
10584
+ auto IsSupportedInstruction = [&](Instruction *I) {
10585
+ return I && isSupportedOpcode( I->getOpcode()) &&
10576
10586
(!doesNotNeedToBeScheduled(I) || !R.isVectorized(I));
10577
10587
};
10578
10588
// Exclude operand instructions immediately to improve compile time, since
10579
10589
// they cannot be scheduled anyway.
10580
10590
SmallDenseSet<Value *, 8> Operands;
10591
+ SmallMapVector<unsigned, SmallVector<Instruction *>, 4> Candidates;
10581
10592
for (Value *V : VL) {
10582
10593
auto *I = dyn_cast<Instruction>(V);
10583
10594
if (!I)
10584
10595
continue;
10585
10596
if (!DT.isReachableFromEntry(I->getParent()))
10586
10597
continue;
10587
- if (!MainOp ) {
10588
- MainOp = I ;
10598
+ if (Candidates.empty() ) {
10599
+ Candidates.try_emplace(I->getOpcode()).first->second.push_back(I) ;
10589
10600
Parent = I->getParent();
10590
10601
Operands.insert(I->op_begin(), I->op_end());
10591
10602
continue;
10592
10603
}
10593
10604
if (Parent == I->getParent()) {
10594
- if (!IsSupportedOpcode(MainOp) && !Operands.contains(I))
10595
- MainOp = I;
10605
+ Candidates.try_emplace(I->getOpcode()).first->second.push_back(I);
10596
10606
Operands.insert(I->op_begin(), I->op_end());
10597
10607
continue;
10598
10608
}
@@ -10604,24 +10614,35 @@ class InstructionsCompatibilityAnalysis {
10604
10614
(NodeA->getDFSNumIn() == NodeB->getDFSNumIn()) &&
10605
10615
"Different nodes should have different DFS numbers");
10606
10616
if (NodeA->getDFSNumIn() < NodeB->getDFSNumIn()) {
10607
- MainOp = I;
10617
+ Candidates.clear();
10618
+ Candidates.try_emplace(I->getOpcode()).first->second.push_back(I);
10608
10619
Parent = I->getParent();
10609
10620
Operands.clear();
10610
10621
Operands.insert(I->op_begin(), I->op_end());
10611
10622
}
10612
10623
}
10613
- if (!IsSupportedOpcode(MainOp) || Operands.contains(MainOp)) {
10614
- MainOp = nullptr;
10615
- return;
10624
+ unsigned BestOpcodeNum = 0;
10625
+ MainOp = nullptr;
10626
+ for (const auto &P : Candidates) {
10627
+ if (P.second.size() < BestOpcodeNum)
10628
+ continue;
10629
+ for (Instruction *I : P.second) {
10630
+ if (IsSupportedInstruction(I) && !Operands.contains(I)) {
10631
+ MainOp = I;
10632
+ BestOpcodeNum = P.second.size();
10633
+ break;
10634
+ }
10635
+ }
10616
10636
}
10617
- MainOpcode = MainOp->getOpcode();
10637
+ if (MainOp)
10638
+ MainOpcode = MainOp->getOpcode();
10618
10639
}
10619
10640
10620
10641
/// Returns the idempotent value for the \p MainOp with the detected \p
10621
10642
/// MainOpcode. For Add and LShr, returns 0. For Or, it should choose between
10622
10643
/// false and the operand itself, since V or V == V.
10623
10644
Value *selectBestIdempotentValue() const {
10624
- assert(MainOpcode == Instruction::Add && "Unsupported opcode");
10645
+ assert(isSupportedOpcode( MainOpcode) && "Unsupported opcode");
10625
10646
return ConstantExpr::getBinOpIdentity(MainOpcode, MainOp->getType(),
10626
10647
!MainOp->isCommutative());
10627
10648
}
@@ -10634,13 +10655,8 @@ class InstructionsCompatibilityAnalysis {
10634
10655
return {V, V};
10635
10656
if (!S.isCopyableElement(V))
10636
10657
return convertTo(cast<Instruction>(V), S).second;
10637
- switch (MainOpcode) {
10638
- case Instruction::Add:
10639
- return {V, selectBestIdempotentValue()};
10640
- default:
10641
- break;
10642
- }
10643
- llvm_unreachable("Unsupported opcode");
10658
+ assert(isSupportedOpcode(MainOpcode) && "Unsupported opcode");
10659
+ return {V, selectBestIdempotentValue()};
10644
10660
}
10645
10661
10646
10662
/// Builds operands for the original instructions.
@@ -10853,6 +10869,21 @@ class InstructionsCompatibilityAnalysis {
10853
10869
}
10854
10870
if (!Res)
10855
10871
return InstructionsState::invalid();
10872
+ constexpr TTI::TargetCostKind Kind = TTI::TCK_RecipThroughput;
10873
+ InstructionCost ScalarCost = TTI.getInstructionCost(S.getMainOp(), Kind);
10874
+ InstructionCost VectorCost;
10875
+ FixedVectorType *VecTy =
10876
+ getWidenedType(S.getMainOp()->getType(), VL.size());
10877
+ switch (MainOpcode) {
10878
+ case Instruction::Add:
10879
+ case Instruction::LShr:
10880
+ VectorCost = TTI.getArithmeticInstrCost(MainOpcode, VecTy, Kind);
10881
+ break;
10882
+ default:
10883
+ llvm_unreachable("Unexpected instruction.");
10884
+ }
10885
+ if (VectorCost > ScalarCost)
10886
+ return InstructionsState::invalid();
10856
10887
return S;
10857
10888
}
10858
10889
assert(Operands.size() == 2 && "Unexpected number of operands!");
@@ -21090,6 +21121,7 @@ void BoUpSLP::BlockScheduling::calculateDependencies(
21090
21121
ArrayRef<Value *> Op = EI.UserTE->getOperand(EI.EdgeIdx);
21091
21122
const auto *It = find(Op, CD->getInst());
21092
21123
assert(It != Op.end() && "Lane not set");
21124
+ SmallPtrSet<Instruction *, 4> Visited;
21093
21125
do {
21094
21126
int Lane = std::distance(Op.begin(), It);
21095
21127
assert(Lane >= 0 && "Lane not set");
@@ -21111,13 +21143,15 @@ void BoUpSLP::BlockScheduling::calculateDependencies(
21111
21143
(InsertInReadyList && UseSD->isReady()))
21112
21144
WorkList.push_back(UseSD);
21113
21145
}
21114
- } else if (ScheduleData *UseSD = getScheduleData(In)) {
21115
- CD->incDependencies();
21116
- if (!UseSD->isScheduled())
21117
- CD->incrementUnscheduledDeps(1);
21118
- if (!UseSD->hasValidDependencies() ||
21119
- (InsertInReadyList && UseSD->isReady()))
21120
- WorkList.push_back(UseSD);
21146
+ } else if (Visited.insert(In).second) {
21147
+ if (ScheduleData *UseSD = getScheduleData(In)) {
21148
+ CD->incDependencies();
21149
+ if (!UseSD->isScheduled())
21150
+ CD->incrementUnscheduledDeps(1);
21151
+ if (!UseSD->hasValidDependencies() ||
21152
+ (InsertInReadyList && UseSD->isReady()))
21153
+ WorkList.push_back(UseSD);
21154
+ }
21121
21155
}
21122
21156
It = find(make_range(std::next(It), Op.end()), CD->getInst());
21123
21157
} while (It != Op.end());
@@ -21875,9 +21909,11 @@ bool BoUpSLP::collectValuesToDemote(
21875
21909
return all_of(E.Scalars, [&](Value *V) {
21876
21910
if (isa<PoisonValue>(V))
21877
21911
return true;
21912
+ APInt ShiftedBits = APInt::getBitsSetFrom(OrigBitWidth, BitWidth);
21913
+ if (E.isCopyableElement(V))
21914
+ return MaskedValueIsZero(V, ShiftedBits, SimplifyQuery(*DL));
21878
21915
auto *I = cast<Instruction>(V);
21879
21916
KnownBits AmtKnownBits = computeKnownBits(I->getOperand(1), *DL);
21880
- APInt ShiftedBits = APInt::getBitsSetFrom(OrigBitWidth, BitWidth);
21881
21917
return AmtKnownBits.getMaxValue().ult(BitWidth) &&
21882
21918
MaskedValueIsZero(I->getOperand(0), ShiftedBits,
21883
21919
SimplifyQuery(*DL));
0 commit comments