diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index 6fcd606afaa22..b332da0eb1a49 100644 --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -511,15 +511,25 @@ static bool isSplat(ArrayRef VL) { } /// \returns True if \p I is commutative, handles CmpInst and BinaryOperator. -static bool isCommutative(Instruction *I) { +/// For BinaryOperator, it also checks if \p InstWithUses is used in specific +/// patterns that make it effectively commutative (like equality comparisons +/// with zero). +/// In most cases, users should not call this function directly (since \p I and +/// \p InstWithUses are the same). However, when analyzing interchangeable +/// instructions, we need to use the converted opcode along with the original +/// uses. +/// \param I The instruction to check for commutativity +/// \param InstWithUses The instruction whose uses are analyzed for special +/// patterns +static bool isCommutative(Instruction *I, Instruction *InstWithUses) { if (auto *Cmp = dyn_cast(I)) return Cmp->isCommutative(); if (auto *BO = dyn_cast(I)) return BO->isCommutative() || (BO->getOpcode() == Instruction::Sub && - !BO->hasNUsesOrMore(UsesLimit) && + !InstWithUses->hasNUsesOrMore(UsesLimit) && all_of( - BO->uses(), + InstWithUses->uses(), [](const Use &U) { // Commutative, if icmp eq/ne sub, 0 CmpPredicate Pred; @@ -536,14 +546,24 @@ static bool isCommutative(Instruction *I) { Flag->isOne()); })) || (BO->getOpcode() == Instruction::FSub && - !BO->hasNUsesOrMore(UsesLimit) && - all_of(BO->uses(), [](const Use &U) { + !InstWithUses->hasNUsesOrMore(UsesLimit) && + all_of(InstWithUses->uses(), [](const Use &U) { return match(U.getUser(), m_Intrinsic(m_Specific(U.get()))); })); return I->isCommutative(); } +/// This is a helper function to check whether \p I is commutative. +/// This is a convenience wrapper that calls the two-parameter version of +/// isCommutative with the same instruction for both parameters. This is +/// the common case where the instruction being checked for commutativity +/// is the same as the instruction whose uses are analyzed for special +/// patterns (see the two-parameter version above for details). +/// \param I The instruction to check for commutativity +/// \returns true if the instruction is commutative, false otherwise +static bool isCommutative(Instruction *I) { return isCommutative(I, I); } + template static std::optional getInsertExtractIndex(const Value *Inst, unsigned Offset) { @@ -2898,7 +2918,11 @@ class BoUpSLP { continue; } auto [SelectedOp, Ops] = convertTo(cast(V), S); - bool IsInverseOperation = !isCommutative(SelectedOp); + // We cannot check commutativity by the converted instruction + // (SelectedOp) because isCommutative also examines def-use + // relationships. + bool IsInverseOperation = + !isCommutative(SelectedOp, cast(V)); for (unsigned OpIdx : seq(ArgSize)) { bool APO = (OpIdx == 0) ? false : IsInverseOperation; OpsVec[OpIdx][Lane] = {Operands[OpIdx][Lane], APO, false}; diff --git a/llvm/test/Transforms/SLPVectorizer/isCommutative.ll b/llvm/test/Transforms/SLPVectorizer/isCommutative.ll new file mode 100644 index 0000000000000..704ac8295f55b --- /dev/null +++ b/llvm/test/Transforms/SLPVectorizer/isCommutative.ll @@ -0,0 +1,34 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt -passes=slp-vectorizer -S %s | FileCheck %s + +define i16 @check_isCommutative_with_the_original_source() { +; CHECK-LABEL: @check_isCommutative_with_the_original_source( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[COND3:%.*]] = select i1 true, i16 1, i16 0 +; CHECK-NEXT: ret i16 [[COND3]] +; +entry: + %sub = sub i16 0, -1 + %cmp = icmp eq i16 %sub, 1 + + %sub1 = sub i16 0, -1 + %cmp2 = icmp eq i16 %sub1, 1 + %cond3 = select i1 %cmp2, i16 1, i16 0 + + %sub5 = sub nsw i16 0, 0 + %cmp6 = icmp eq i16 %sub5, 0 + %cmp9 = icmp eq i16 %sub5, 0 + + %sub12 = sub nsw i16 0, 0 + %cmp13 = icmp eq i16 %sub12, 0 + + %sub16 = sub nsw i16 0, 0 + %cmp17 = icmp eq i16 %sub16, 0 + + %sub20 = sub nsw i16 0, 0 + %cmp21 = icmp eq i16 %sub20, 0 + %cmp24 = icmp eq i16 %sub20, 0 + + ret i16 %cond3 +} +