Skip to content

Commit 66be00d

Browse files
authored
[VPlan] Introduce m_Cmp; match more compares (#154771)
Extend [Specific]Cmp_match to handle floating-point compares, and introduce m_Cmp that matches both integer and floating-point compares. Use it in simplifyRecipe to match and simplify the general case of compares. The change has necessitated a bugfix in VPReplicateRecipe::execute.
1 parent feac561 commit 66be00d

File tree

7 files changed

+126
-108
lines changed

7 files changed

+126
-108
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7002,12 +7002,12 @@ static bool planContainsAdditionalSimplifications(VPlan &Plan,
70027002
if (Instruction *UI = GetInstructionForCost(&R)) {
70037003
// If we adjusted the predicate of the recipe, the cost in the legacy
70047004
// cost model may be different.
7005-
if (auto *WidenCmp = dyn_cast<VPWidenRecipe>(&R)) {
7006-
if ((WidenCmp->getOpcode() == Instruction::ICmp ||
7007-
WidenCmp->getOpcode() == Instruction::FCmp) &&
7008-
WidenCmp->getPredicate() != cast<CmpInst>(UI)->getPredicate())
7009-
return true;
7010-
}
7005+
using namespace VPlanPatternMatch;
7006+
CmpPredicate Pred;
7007+
if (match(&R, m_Cmp(Pred, m_VPValue(), m_VPValue())) &&
7008+
cast<VPRecipeWithIRFlags>(R).getPredicate() !=
7009+
cast<CmpInst>(UI)->getPredicate())
7010+
return true;
70117011
SeenInstrs.insert(UI);
70127012
}
70137013
}

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -805,6 +805,9 @@ class VPIRFlags {
805805

806806
GEPNoWrapFlags getGEPNoWrapFlags() const { return GEPFlags; }
807807

808+
/// Returns true if the recipe has a comparison predicate.
809+
bool hasPredicate() const { return OpType == OperationType::Cmp; }
810+
808811
/// Returns true if the recipe has fast-math flags.
809812
bool hasFastMathFlags() const { return OpType == OperationType::FPMathOp; }
810813

llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h

Lines changed: 50 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -397,24 +397,32 @@ m_c_BinaryOr(const Op0_t &Op0, const Op1_t &Op1) {
397397
return m_c_Binary<Instruction::Or, Op0_t, Op1_t>(Op0, Op1);
398398
}
399399

400-
/// ICmp_match is a variant of BinaryRecipe_match that also binds the comparison
401-
/// predicate.
402-
template <typename Op0_t, typename Op1_t> struct ICmp_match {
400+
/// Cmp_match is a variant of BinaryRecipe_match that also binds the comparison
401+
/// predicate. Opcodes must either be Instruction::ICmp or Instruction::FCmp, or
402+
/// both.
403+
template <typename Op0_t, typename Op1_t, unsigned... Opcodes>
404+
struct Cmp_match {
405+
static_assert((sizeof...(Opcodes) == 1 || sizeof...(Opcodes) == 2) &&
406+
"Expected one or two opcodes");
407+
static_assert(
408+
((Opcodes == Instruction::ICmp || Opcodes == Instruction::FCmp) && ...) &&
409+
"Expected a compare instruction opcode");
410+
403411
CmpPredicate *Predicate = nullptr;
404412
Op0_t Op0;
405413
Op1_t Op1;
406414

407-
ICmp_match(CmpPredicate &Pred, const Op0_t &Op0, const Op1_t &Op1)
415+
Cmp_match(CmpPredicate &Pred, const Op0_t &Op0, const Op1_t &Op1)
408416
: Predicate(&Pred), Op0(Op0), Op1(Op1) {}
409-
ICmp_match(const Op0_t &Op0, const Op1_t &Op1) : Op0(Op0), Op1(Op1) {}
417+
Cmp_match(const Op0_t &Op0, const Op1_t &Op1) : Op0(Op0), Op1(Op1) {}
410418

411419
bool match(const VPValue *V) const {
412420
auto *DefR = V->getDefiningRecipe();
413421
return DefR && match(DefR);
414422
}
415423

416424
bool match(const VPRecipeBase *V) const {
417-
if (m_Binary<Instruction::ICmp>(Op0, Op1).match(V)) {
425+
if ((m_Binary<Opcodes>(Op0, Op1).match(V) || ...)) {
418426
if (Predicate)
419427
*Predicate = cast<VPRecipeWithIRFlags>(V)->getPredicate();
420428
return true;
@@ -423,38 +431,63 @@ template <typename Op0_t, typename Op1_t> struct ICmp_match {
423431
}
424432
};
425433

426-
/// SpecificICmp_match is a variant of ICmp_match that matches the comparison
434+
/// SpecificCmp_match is a variant of Cmp_match that matches the comparison
427435
/// predicate, instead of binding it.
428-
template <typename Op0_t, typename Op1_t> struct SpecificICmp_match {
436+
template <typename Op0_t, typename Op1_t, unsigned... Opcodes>
437+
struct SpecificCmp_match {
429438
const CmpPredicate Predicate;
430439
Op0_t Op0;
431440
Op1_t Op1;
432441

433-
SpecificICmp_match(CmpPredicate Pred, const Op0_t &LHS, const Op1_t &RHS)
442+
SpecificCmp_match(CmpPredicate Pred, const Op0_t &LHS, const Op1_t &RHS)
434443
: Predicate(Pred), Op0(LHS), Op1(RHS) {}
435444

436445
bool match(const VPValue *V) const {
437446
CmpPredicate CurrentPred;
438-
return ICmp_match<Op0_t, Op1_t>(CurrentPred, Op0, Op1).match(V) &&
447+
return Cmp_match<Op0_t, Op1_t, Opcodes...>(CurrentPred, Op0, Op1)
448+
.match(V) &&
439449
CmpPredicate::getMatching(CurrentPred, Predicate);
440450
}
441451
};
442452

443453
template <typename Op0_t, typename Op1_t>
444-
inline ICmp_match<Op0_t, Op1_t> m_ICmp(const Op0_t &Op0, const Op1_t &Op1) {
445-
return ICmp_match<Op0_t, Op1_t>(Op0, Op1);
454+
inline Cmp_match<Op0_t, Op1_t, Instruction::ICmp> m_ICmp(const Op0_t &Op0,
455+
const Op1_t &Op1) {
456+
return Cmp_match<Op0_t, Op1_t, Instruction::ICmp>(Op0, Op1);
446457
}
447458

448459
template <typename Op0_t, typename Op1_t>
449-
inline ICmp_match<Op0_t, Op1_t> m_ICmp(CmpPredicate &Pred, const Op0_t &Op0,
450-
const Op1_t &Op1) {
451-
return ICmp_match<Op0_t, Op1_t>(Pred, Op0, Op1);
460+
inline Cmp_match<Op0_t, Op1_t, Instruction::ICmp>
461+
m_ICmp(CmpPredicate &Pred, const Op0_t &Op0, const Op1_t &Op1) {
462+
return Cmp_match<Op0_t, Op1_t, Instruction::ICmp>(Pred, Op0, Op1);
452463
}
453464

454465
template <typename Op0_t, typename Op1_t>
455-
inline SpecificICmp_match<Op0_t, Op1_t>
466+
inline SpecificCmp_match<Op0_t, Op1_t, Instruction::ICmp>
456467
m_SpecificICmp(CmpPredicate MatchPred, const Op0_t &Op0, const Op1_t &Op1) {
457-
return SpecificICmp_match<Op0_t, Op1_t>(MatchPred, Op0, Op1);
468+
return SpecificCmp_match<Op0_t, Op1_t, Instruction::ICmp>(MatchPred, Op0,
469+
Op1);
470+
}
471+
472+
template <typename Op0_t, typename Op1_t>
473+
inline Cmp_match<Op0_t, Op1_t, Instruction::ICmp, Instruction::FCmp>
474+
m_Cmp(const Op0_t &Op0, const Op1_t &Op1) {
475+
return Cmp_match<Op0_t, Op1_t, Instruction::ICmp, Instruction::FCmp>(Op0,
476+
Op1);
477+
}
478+
479+
template <typename Op0_t, typename Op1_t>
480+
inline Cmp_match<Op0_t, Op1_t, Instruction::ICmp, Instruction::FCmp>
481+
m_Cmp(CmpPredicate &Pred, const Op0_t &Op0, const Op1_t &Op1) {
482+
return Cmp_match<Op0_t, Op1_t, Instruction::ICmp, Instruction::FCmp>(
483+
Pred, Op0, Op1);
484+
}
485+
486+
template <typename Op0_t, typename Op1_t>
487+
inline SpecificCmp_match<Op0_t, Op1_t, Instruction::ICmp, Instruction::FCmp>
488+
m_SpecificCmp(CmpPredicate MatchPred, const Op0_t &Op0, const Op1_t &Op1) {
489+
return SpecificCmp_match<Op0_t, Op1_t, Instruction::ICmp, Instruction::FCmp>(
490+
MatchPred, Op0, Op1);
458491
}
459492

460493
template <typename Op0_t, typename Op1_t>

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2930,6 +2930,9 @@ static void scalarizeInstruction(const Instruction *Instr,
29302930
RepRecipe->applyFlags(*Cloned);
29312931
RepRecipe->applyMetadata(*Cloned);
29322932

2933+
if (RepRecipe->hasPredicate())
2934+
cast<CmpInst>(Cloned)->setPredicate(RepRecipe->getPredicate());
2935+
29332936
if (auto DL = RepRecipe->getDebugLoc())
29342937
State.setDebugLocFrom(DL);
29352938

llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1108,33 +1108,31 @@ static void simplifyRecipe(VPRecipeBase &R, VPTypeAnalysis &TypeInfo) {
11081108
return Def->replaceAllUsesWith(A);
11091109

11101110
// Try to fold Not into compares by adjusting the predicate in-place.
1111-
if (auto *WideCmp = dyn_cast<VPWidenRecipe>(A)) {
1112-
if ((WideCmp->getOpcode() == Instruction::ICmp ||
1113-
WideCmp->getOpcode() == Instruction::FCmp) &&
1114-
all_of(WideCmp->users(), [&WideCmp](VPUser *U) {
1115-
return match(U, m_CombineOr(m_Not(m_Specific(WideCmp)),
1116-
m_Select(m_Specific(WideCmp),
1117-
m_VPValue(), m_VPValue())));
1111+
CmpPredicate Pred;
1112+
if (match(A, m_Cmp(Pred, m_VPValue(), m_VPValue()))) {
1113+
auto *Cmp = cast<VPRecipeWithIRFlags>(A);
1114+
if (all_of(Cmp->users(), [&Cmp](VPUser *U) {
1115+
return match(U, m_CombineOr(m_Not(m_Specific(Cmp)),
1116+
m_Select(m_Specific(Cmp), m_VPValue(),
1117+
m_VPValue())));
11181118
})) {
1119-
WideCmp->setPredicate(
1120-
CmpInst::getInversePredicate(WideCmp->getPredicate()));
1121-
for (VPUser *U : to_vector(WideCmp->users())) {
1119+
Cmp->setPredicate(CmpInst::getInversePredicate(Pred));
1120+
for (VPUser *U : to_vector(Cmp->users())) {
11221121
auto *R = cast<VPSingleDefRecipe>(U);
1123-
if (match(R, m_Select(m_Specific(WideCmp), m_VPValue(X),
1124-
m_VPValue(Y)))) {
1122+
if (match(R, m_Select(m_Specific(Cmp), m_VPValue(X), m_VPValue(Y)))) {
11251123
// select (cmp pred), x, y -> select (cmp inv_pred), y, x
11261124
R->setOperand(1, Y);
11271125
R->setOperand(2, X);
11281126
} else {
11291127
// not (cmp pred) -> cmp inv_pred
1130-
assert(match(R, m_Not(m_Specific(WideCmp))) && "Unexpected user");
1131-
R->replaceAllUsesWith(WideCmp);
1128+
assert(match(R, m_Not(m_Specific(Cmp))) && "Unexpected user");
1129+
R->replaceAllUsesWith(Cmp);
11321130
}
11331131
}
1134-
// If WideCmp doesn't have a debug location, use the one from the
1135-
// negation, to preserve the location.
1136-
if (!WideCmp->getDebugLoc() && R.getDebugLoc())
1137-
WideCmp->setDebugLoc(R.getDebugLoc());
1132+
// If Cmp doesn't have a debug location, use the one from the negation,
1133+
// to preserve the location.
1134+
if (!Cmp->getDebugLoc() && R.getDebugLoc())
1135+
Cmp->setDebugLoc(R.getDebugLoc());
11381136
}
11391137
}
11401138
}

llvm/test/Transforms/LoopVectorize/AArch64/masked-call.ll

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -973,18 +973,16 @@ define void @test_widen_exp_v2(ptr noalias %p2, ptr noalias %p, i64 %n) #5 {
973973
; TFA_INTERLEAVE-NEXT: [[ACTIVE_LANE_MASK_ENTRY1:%.*]] = icmp ult i64 1, [[TMP0]]
974974
; TFA_INTERLEAVE-NEXT: br label %[[VECTOR_BODY:.*]]
975975
; TFA_INTERLEAVE: [[VECTOR_BODY]]:
976-
; TFA_INTERLEAVE-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[ENTRY]] ], [ [[INDEX_NEXT:%.*]], %[[TMP19:.*]] ]
977-
; TFA_INTERLEAVE-NEXT: [[ACTIVE_LANE_MASK:%.*]] = phi i1 [ [[ACTIVE_LANE_MASK_ENTRY]], %[[ENTRY]] ], [ [[ACTIVE_LANE_MASK_NEXT:%.*]], %[[TMP19]] ]
978-
; TFA_INTERLEAVE-NEXT: [[ACTIVE_LANE_MASK2:%.*]] = phi i1 [ [[ACTIVE_LANE_MASK_ENTRY1]], %[[ENTRY]] ], [ [[ACTIVE_LANE_MASK_NEXT6:%.*]], %[[TMP19]] ]
976+
; TFA_INTERLEAVE-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[ENTRY]] ], [ [[INDEX_NEXT:%.*]], %[[TMP18:.*]] ]
977+
; TFA_INTERLEAVE-NEXT: [[ACTIVE_LANE_MASK:%.*]] = phi i1 [ [[ACTIVE_LANE_MASK_ENTRY]], %[[ENTRY]] ], [ [[ACTIVE_LANE_MASK_NEXT:%.*]], %[[TMP18]] ]
978+
; TFA_INTERLEAVE-NEXT: [[ACTIVE_LANE_MASK2:%.*]] = phi i1 [ [[ACTIVE_LANE_MASK_ENTRY1]], %[[ENTRY]] ], [ [[ACTIVE_LANE_MASK_NEXT6:%.*]], %[[TMP18]] ]
979979
; TFA_INTERLEAVE-NEXT: [[TMP4:%.*]] = load double, ptr [[P2]], align 8
980980
; TFA_INTERLEAVE-NEXT: [[TMP5:%.*]] = tail call double @llvm.exp.f64(double [[TMP4]]) #[[ATTR7:[0-9]+]]
981981
; TFA_INTERLEAVE-NEXT: [[TMP6:%.*]] = tail call double @llvm.exp.f64(double [[TMP4]]) #[[ATTR7]]
982-
; TFA_INTERLEAVE-NEXT: [[TMP7:%.*]] = fcmp ogt double [[TMP5]], 0.000000e+00
983-
; TFA_INTERLEAVE-NEXT: [[TMP8:%.*]] = fcmp ogt double [[TMP6]], 0.000000e+00
984-
; TFA_INTERLEAVE-NEXT: [[TMP9:%.*]] = xor i1 [[TMP7]], true
985-
; TFA_INTERLEAVE-NEXT: [[TMP10:%.*]] = xor i1 [[TMP8]], true
986-
; TFA_INTERLEAVE-NEXT: [[TMP11:%.*]] = select i1 [[ACTIVE_LANE_MASK]], i1 [[TMP9]], i1 false
987-
; TFA_INTERLEAVE-NEXT: [[TMP12:%.*]] = select i1 [[ACTIVE_LANE_MASK2]], i1 [[TMP10]], i1 false
982+
; TFA_INTERLEAVE-NEXT: [[TMP7:%.*]] = fcmp ule double [[TMP5]], 0.000000e+00
983+
; TFA_INTERLEAVE-NEXT: [[TMP8:%.*]] = fcmp ule double [[TMP6]], 0.000000e+00
984+
; TFA_INTERLEAVE-NEXT: [[TMP11:%.*]] = select i1 [[ACTIVE_LANE_MASK]], i1 [[TMP7]], i1 false
985+
; TFA_INTERLEAVE-NEXT: [[TMP12:%.*]] = select i1 [[ACTIVE_LANE_MASK2]], i1 [[TMP8]], i1 false
988986
; TFA_INTERLEAVE-NEXT: [[PREDPHI:%.*]] = select i1 [[TMP11]], double 1.000000e+00, double 0.000000e+00
989987
; TFA_INTERLEAVE-NEXT: [[PREDPHI3:%.*]] = select i1 [[TMP12]], double 1.000000e+00, double 0.000000e+00
990988
; TFA_INTERLEAVE-NEXT: [[SPEC_SELECT:%.*]] = select i1 [[ACTIVE_LANE_MASK2]], double [[PREDPHI3]], double [[PREDPHI]]
@@ -993,11 +991,11 @@ define void @test_widen_exp_v2(ptr noalias %p2, ptr noalias %p, i64 %n) #5 {
993991
; TFA_INTERLEAVE-NEXT: [[TMP15:%.*]] = xor i1 [[TMP13]], true
994992
; TFA_INTERLEAVE-NEXT: [[TMP16:%.*]] = xor i1 [[TMP14]], true
995993
; TFA_INTERLEAVE-NEXT: [[TMP17:%.*]] = or i1 [[TMP15]], [[TMP16]]
996-
; TFA_INTERLEAVE-NEXT: br i1 [[TMP17]], label %[[BB18:.*]], label %[[TMP19]]
997-
; TFA_INTERLEAVE: [[BB18]]:
994+
; TFA_INTERLEAVE-NEXT: br i1 [[TMP17]], label %[[BB16:.*]], label %[[TMP18]]
995+
; TFA_INTERLEAVE: [[BB16]]:
998996
; TFA_INTERLEAVE-NEXT: store double [[SPEC_SELECT]], ptr [[P]], align 8
999-
; TFA_INTERLEAVE-NEXT: br label %[[TMP19]]
1000-
; TFA_INTERLEAVE: [[TMP19]]:
997+
; TFA_INTERLEAVE-NEXT: br label %[[TMP18]]
998+
; TFA_INTERLEAVE: [[TMP18]]:
1001999
; TFA_INTERLEAVE-NEXT: [[INDEX_NEXT]] = add i64 [[INDEX]], 2
10021000
; TFA_INTERLEAVE-NEXT: [[TMP20:%.*]] = add i64 [[INDEX]], 1
10031001
; TFA_INTERLEAVE-NEXT: [[ACTIVE_LANE_MASK_NEXT]] = icmp ult i64 [[INDEX]], [[TMP3]]

0 commit comments

Comments
 (0)