-
Notifications
You must be signed in to change notification settings - Fork 15k
[PatternMatch] Introduce match functor (NFC) #159386
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
A common idiom is the usage of the PatternMatch match function within a functional algorithm like all_of. Introduce a match functor to shorten this idiom. Co-authored-by: Luke Lau <[email protected]>
|
@llvm/pr-subscribers-llvm-analysis @llvm/pr-subscribers-llvm-transforms Author: Ramkumar Ramachandra (artagnon) Changes: A common idiom is the usage of the PatternMatch match function within a functional algorithm like all_of. Introduce a match functor to shorten this idiom. Full diff: https://github.com/llvm/llvm-project/pull/159386.diff 11 Files Affected:
diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 2cb78904dd799..a16776c62f32b 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -50,6 +50,19 @@ template <typename Val, typename Pattern> bool match(Val *V, const Pattern &P) {
return P.match(V);
}
+template <typename Val, typename Pattern> struct MatchFunctor {
+ const Pattern &P;
+ MatchFunctor(const Pattern &P) : P(P) {}
+ bool operator()(Val *V) const { return P.match(V); }
+};
+
+/// A match functor that can be used as a UnaryPredicate in functional
+/// algorithms like all_of.
+template <typename Val = const Value, typename Pattern>
+MatchFunctor<Val, Pattern> match_fn(const Pattern &P) {
+ return P;
+}
+
template <typename Pattern> bool match(ArrayRef<int> Mask, const Pattern &P) {
return P.match(Mask);
}
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 7bff13d59528c..ff7f1aefebc50 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -5028,14 +5028,12 @@ static Value *simplifyGEPInst(Type *SrcTy, Value *Ptr,
}
// All-zero GEP is a no-op, unless it performs a vector splat.
- if (Ptr->getType() == GEPTy &&
- all_of(Indices, [](const auto *V) { return match(V, m_Zero()); }))
+ if (Ptr->getType() == GEPTy && all_of(Indices, match_fn(m_Zero())))
return Ptr;
// getelementptr poison, idx -> poison
// getelementptr baseptr, poison -> poison
- if (isa<PoisonValue>(Ptr) ||
- any_of(Indices, [](const auto *V) { return isa<PoisonValue>(V); }))
+ if (isa<PoisonValue>(Ptr) || any_of(Indices, match_fn(m_Poison())))
return PoisonValue::get(GEPTy);
// getelementptr undef, idx -> undef
@@ -5092,8 +5090,7 @@ static Value *simplifyGEPInst(Type *SrcTy, Value *Ptr,
}
if (!IsScalableVec && Q.DL.getTypeAllocSize(LastType) == 1 &&
- all_of(Indices.drop_back(1),
- [](Value *Idx) { return match(Idx, m_Zero()); })) {
+ all_of(Indices.drop_back(1), match_fn(m_Zero()))) {
unsigned IdxWidth =
Q.DL.getIndexSizeInBits(Ptr->getType()->getPointerAddressSpace());
if (Q.DL.getTypeSizeInBits(Indices.back()->getType()) == IdxWidth) {
@@ -5123,8 +5120,7 @@ static Value *simplifyGEPInst(Type *SrcTy, Value *Ptr,
}
// Check to see if this is constant foldable.
- if (!isa<Constant>(Ptr) ||
- !all_of(Indices, [](Value *V) { return isa<Constant>(V); }))
+ if (!isa<Constant>(Ptr) || !all_of(Indices, match_fn(m_Constant())))
return nullptr;
if (!ConstantExpr::isSupportedGetElementPtr(SrcTy))
@@ -5662,7 +5658,7 @@ static Constant *simplifyFPOp(ArrayRef<Value *> Ops, FastMathFlags FMF,
RoundingMode Rounding) {
// Poison is independent of anything else. It always propagates from an
// operand to a math result.
- if (any_of(Ops, [](Value *V) { return match(V, m_Poison()); }))
+ if (any_of(Ops, match_fn(m_Poison())))
return PoisonValue::get(Ops[0]->getType());
for (Value *V : Ops) {
@@ -7126,7 +7122,7 @@ static Value *simplifyInstructionWithOperands(Instruction *I,
switch (I->getOpcode()) {
default:
- if (llvm::all_of(NewOps, [](Value *V) { return isa<Constant>(V); })) {
+ if (all_of(NewOps, match_fn(m_Constant()))) {
SmallVector<Constant *, 8> NewConstOps(NewOps.size());
transform(NewOps, NewConstOps.begin(),
[](Value *V) { return cast<Constant>(V); });
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 129823e0e98a3..e9e9a127aae92 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -250,9 +250,8 @@ bool llvm::haveNoCommonBitsSet(const WithCache<const Value *> &LHSCache,
}
bool llvm::isOnlyUsedInZeroComparison(const Instruction *I) {
- return !I->user_empty() && all_of(I->users(), [](const User *U) {
- return match(U, m_ICmp(m_Value(), m_Zero()));
- });
+ return !I->user_empty() &&
+ all_of(I->users(), match_fn(m_ICmp(m_Value(), m_Zero())));
}
bool llvm::isOnlyUsedInZeroEqualityComparison(const Instruction *I) {
diff --git a/llvm/lib/CodeGen/InterleavedAccessPass.cpp b/llvm/lib/CodeGen/InterleavedAccessPass.cpp
index e3ded12a1847b..a6a9b5058ad94 100644
--- a/llvm/lib/CodeGen/InterleavedAccessPass.cpp
+++ b/llvm/lib/CodeGen/InterleavedAccessPass.cpp
@@ -312,10 +312,9 @@ bool InterleavedAccessImpl::lowerInterleavedLoad(
continue;
}
if (auto *BI = dyn_cast<BinaryOperator>(User)) {
- if (!BI->user_empty() && all_of(BI->users(), [](auto *U) {
- auto *SVI = dyn_cast<ShuffleVectorInst>(U);
- return SVI && isa<UndefValue>(SVI->getOperand(1));
- })) {
+ using namespace PatternMatch;
+ if (!BI->user_empty() &&
+ all_of(BI->users(), match_fn(m_Shuffle(m_Value(), m_Undef())))) {
for (auto *SVI : BI->users())
BinOpShuffles.insert(cast<ShuffleVectorInst>(SVI));
continue;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 00951fde0cf8a..d1ca0a6a393c5 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -2354,12 +2354,8 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
// and let's try to sink `(sub 0, b)` into `b` itself. But only if this isn't
// a pure negation used by a select that looks like abs/nabs.
bool IsNegation = match(Op0, m_ZeroInt());
- if (!IsNegation || none_of(I.users(), [&I, Op1](const User *U) {
- const Instruction *UI = dyn_cast<Instruction>(U);
- if (!UI)
- return false;
- return match(UI, m_c_Select(m_Specific(Op1), m_Specific(&I)));
- })) {
+ if (!IsNegation || none_of(I.users(), match_fn(m_c_Select(m_Specific(Op1),
+ m_Specific(&I))))) {
if (Value *NegOp1 = Negator::Negate(IsNegation, /* IsNSW */ IsNegation &&
I.hasNoSignedWrap(),
Op1, *this))
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 17cf4154f8dbd..6ad493772d170 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1416,9 +1416,7 @@ InstCombinerImpl::foldShuffledIntrinsicOperands(IntrinsicInst *II) {
// At least 1 operand must be a shuffle with 1 use because we are creating 2
// instructions.
- if (none_of(II->args(), [](Value *V) {
- return isa<ShuffleVectorInst>(V) && V->hasOneUse();
- }))
+ if (none_of(II->args(), match_fn(m_OneUse(m_Shuffle(m_Value(), m_Value())))))
return nullptr;
// See if all arguments are shuffled with the same mask.
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 99ea04816681c..250aa5e073fa4 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -1340,7 +1340,7 @@ Instruction *InstCombinerImpl::foldICmpWithConstant(ICmpInst &Cmp) {
return nullptr;
if (auto *Phi = dyn_cast<PHINode>(Op0))
- if (all_of(Phi->operands(), [](Value *V) { return isa<Constant>(V); })) {
+ if (all_of(Phi->operands(), match_fn(m_Constant()))) {
SmallVector<Constant *> Ops;
for (Value *V : Phi->incoming_values()) {
Constant *Res =
diff --git a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
index 15e7172c6ce12..1349f44031adf 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
@@ -340,7 +340,7 @@ bool InstCombinerImpl::foldIntegerTypedPHI(PHINode &PN) {
Instruction *InstCombinerImpl::foldPHIArgIntToPtrToPHI(PHINode &PN) {
// convert ptr2int ( phi[ int2ptr(ptr2int(x))] ) --> ptr2int ( phi [ x ] )
// Make sure all uses of phi are ptr2int.
- if (!all_of(PN.users(), [](User *U) { return isa<PtrToIntInst>(U); }))
+ if (!all_of(PN.users(), match_fn(m_PtrToInt(m_Value()))))
return nullptr;
// Iterating over all operands to check presence of target pointers for
@@ -1299,7 +1299,7 @@ static Value *simplifyUsingControlFlow(InstCombiner &Self, PHINode &PN,
// \ /
// phi [v1] [v2]
// Make sure all inputs are constants.
- if (!all_of(PN.operands(), [](Value *V) { return isa<ConstantInt>(V); }))
+ if (!all_of(PN.operands(), match_fn(m_ConstantInt())))
return nullptr;
BasicBlock *BB = PN.getParent();
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 8f9d0bf6240d5..4ea75409252bd 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -3257,7 +3257,7 @@ static Instruction *foldNestedSelects(SelectInst &OuterSelVal,
// Profitability check - avoid increasing instruction count.
if (none_of(ArrayRef<Value *>({OuterSelVal.getCondition(), InnerSelVal}),
- [](Value *V) { return V->hasOneUse(); }))
+ match_fn(m_OneUse(m_Value()))))
return nullptr;
// The appropriate hand of the outermost `select` must be a select itself.
diff --git a/llvm/lib/Transforms/Scalar/LICM.cpp b/llvm/lib/Transforms/Scalar/LICM.cpp
index 092a0fb264c28..bab1f2a90a8fd 100644
--- a/llvm/lib/Transforms/Scalar/LICM.cpp
+++ b/llvm/lib/Transforms/Scalar/LICM.cpp
@@ -437,10 +437,9 @@ bool LoopInvariantCodeMotion::runOnLoop(Loop *L, AAResults *AA, LoopInfo *LI,
// potentially happen in other passes where instructions are being moved
// across that edge.
bool HasCoroSuspendInst = llvm::any_of(L->getBlocks(), [](BasicBlock *BB) {
- return llvm::any_of(*BB, [](Instruction &I) {
- IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
- return II && II->getIntrinsicID() == Intrinsic::coro_suspend;
- });
+ using namespace PatternMatch;
+ return any_of(make_pointer_range(*BB),
+ match_fn(m_Intrinsic<Intrinsic::coro_suspend>()));
});
MemorySSAUpdater MSSAU(MSSA);
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 8aafe14c0cbe0..27ddd1ebff887 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -22374,10 +22374,10 @@ void BoUpSLP::computeMinimumValueSizes() {
IsTruncRoot = true;
}
bool IsSignedCmp = false;
- if (UserIgnoreList && all_of(*UserIgnoreList, [](Value *V) {
- return match(V, m_SMin(m_Value(), m_Value())) ||
- match(V, m_SMax(m_Value(), m_Value()));
- }))
+ if (UserIgnoreList &&
+ all_of(*UserIgnoreList,
+ match_fn(m_CombineOr(m_SMin(m_Value(), m_Value()),
+ m_SMax(m_Value(), m_Value())))))
IsSignedCmp = true;
while (NodeIdx < VectorizableTree.size()) {
ArrayRef<Value *> TreeRoot = VectorizableTree[NodeIdx]->Scalars;
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for picking this up, it's very satisfying to see the lambdas eliminated :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
very nice, thanks!
A common idiom is the usage of the PatternMatch match function within a functional algorithm like all_of. Introduce a match functor to shorten this idiom.