From 229ec2822e5c5d959f287eb8f93fbce0007e3b13 Mon Sep 17 00:00:00 2001 From: woruyu <1214539920@qq.com> Date: Fri, 4 Jul 2025 19:20:07 +0800 Subject: [PATCH 1/3] [DAG] SDPatternMatch m_Zero/m_One/m_AllOnes have inconsistent undef handling --- llvm/include/llvm/CodeGen/SDPatternMatch.h | 39 ++++++++++++++++--- llvm/include/llvm/CodeGen/SelectionDAGNodes.h | 4 ++ llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 2 +- .../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 12 ++++++ llvm/lib/Target/X86/X86ISelLowering.cpp | 20 +++++----- 5 files changed, 59 insertions(+), 18 deletions(-) diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h index 35322c32a8283..7c5cdbbeb0ca8 100644 --- a/llvm/include/llvm/CodeGen/SDPatternMatch.h +++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h @@ -1100,19 +1100,46 @@ inline SpecificInt_match m_SpecificInt(uint64_t V) { return SpecificInt_match(APInt(64, V)); } -inline SpecificInt_match m_Zero() { return m_SpecificInt(0U); } -inline SpecificInt_match m_One() { return m_SpecificInt(1U); } +struct Zero_match { + bool AllowUndefs; + + explicit Zero_match(bool AllowUndefs) : AllowUndefs(AllowUndefs) {} + + template + bool match(const MatchContext &, SDValue N) const { + return isZeroOrZeroSplat(N, AllowUndefs); + } +}; + +struct Ones_match { + bool AllowUndefs; + + Ones_match(bool AllowUndefs) : AllowUndefs(AllowUndefs) {} + + template bool match(const MatchContext &, SDValue N) { + return isOnesOrOnesSplat(N, AllowUndefs); + } +}; struct AllOnes_match { + bool AllowUndefs; - AllOnes_match() = default; + AllOnes_match(bool AllowUndefs) : AllowUndefs(AllowUndefs) {} template bool match(const MatchContext &, SDValue N) { - return isAllOnesOrAllOnesSplat(N); + return isAllOnesOrAllOnesSplat(N, AllowUndefs); } }; -inline AllOnes_match m_AllOnes() { return AllOnes_match(); } +inline Ones_match m_One(bool AllowUndefs = false) { + return Ones_match(AllowUndefs); +} +inline Zero_match m_Zero(bool AllowUndefs = false) { + return Zero_match(AllowUndefs); +} +inline AllOnes_match m_AllOnes(bool AllowUndefs = false) { + return AllOnes_match(AllowUndefs); +} /// Match true boolean value based on the information provided by /// TargetLowering. @@ -1189,7 +1216,7 @@ inline CondCode_match m_SpecificCondCode(ISD::CondCode CC) { /// Match a negate as a sub(0, v) template -inline BinaryOpc_match m_Neg(const ValTy &V) { +inline BinaryOpc_match m_Neg(const ValTy &V) { return m_Sub(m_Zero(), V); } diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h index a3675eecfea3f..6bfc40afeb55e 100644 --- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h +++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h @@ -1937,6 +1937,10 @@ LLVM_ABI bool isOneOrOneSplat(SDValue V, bool AllowUndefs = false); /// Does not permit build vector implicit truncation. LLVM_ABI bool isAllOnesOrAllOnesSplat(SDValue V, bool AllowUndefs = false); +LLVM_ABI bool isOnesOrOnesSplat(SDValue N, bool AllowUndefs = false); + +LLVM_ABI bool isZeroOrZeroSplat(SDValue N, bool AllowUndefs = false); + /// Return true if \p V is either a integer or FP constant. inline bool isIntOrFPConstant(SDValue V) { return isa(V) || isa(V); diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index d4ad4d3a09381..f94b3a35652fc 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -4281,7 +4281,7 @@ SDValue DAGCombiner::visitSUB(SDNode *N) { return V; // (A - B) - 1 -> add (xor B, -1), A - if (sd_match(N, m_Sub(m_OneUse(m_Sub(m_Value(A), m_Value(B))), m_One()))) + if (sd_match(N, m_Sub(m_OneUse(m_Sub(m_Value(A), m_Value(B))), m_One(true)))) return DAG.getNode(ISD::ADD, DL, VT, A, DAG.getNOT(DL, B, VT)); // Look for: diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index 2a3c8e2b011ad..d6605c3ec77dd 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -12569,6 +12569,18 @@ bool llvm::isAllOnesOrAllOnesSplat(SDValue N, bool AllowUndefs) { return C && C->isAllOnes() && C->getValueSizeInBits(0) == BitWidth; } +bool llvm::isOnesOrOnesSplat(SDValue N, bool AllowUndefs) { + N = peekThroughBitcasts(N); + ConstantSDNode *C = isConstOrConstSplat(N, AllowUndefs); + return C && C->getAPIntValue() == 1; +} + +bool llvm::isZeroOrZeroSplat(SDValue N, bool AllowUndefs) { + N = peekThroughBitcasts(N); + ConstantSDNode *C = isConstOrConstSplat(N, AllowUndefs, true); + return C && C->isZero(); +} + HandleSDNode::~HandleSDNode() { DropOperands(); } diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index afffe51f23a27..7ec666d0b1658 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -57925,22 +57925,20 @@ static SDValue combineAdd(SDNode *N, SelectionDAG &DAG, } } + SDValue X, Y; + // add(psadbw(X,0),psadbw(Y,0)) -> psadbw(add(X,Y),0) // iff X and Y won't overflow. - if (Op0.getOpcode() == X86ISD::PSADBW && Op1.getOpcode() == X86ISD::PSADBW && - ISD::isBuildVectorAllZeros(Op0.getOperand(1).getNode()) && - ISD::isBuildVectorAllZeros(Op1.getOperand(1).getNode())) { - if (DAG.willNotOverflowAdd(false, Op0.getOperand(0), Op1.getOperand(0))) { - MVT OpVT = Op0.getOperand(1).getSimpleValueType(); - SDValue Sum = - DAG.getNode(ISD::ADD, DL, OpVT, Op0.getOperand(0), Op1.getOperand(0)); - return DAG.getNode(X86ISD::PSADBW, DL, VT, Sum, - getZeroVector(OpVT, Subtarget, DAG, DL)); - } + if (sd_match(Op0, m_c_BinOp(X86ISD::PSADBW, m_Value(X), m_Zero())) && + sd_match(Op1, m_c_BinOp(X86ISD::PSADBW, m_Value(Y), m_Zero())) && + DAG.willNotOverflowAdd(/*IsSigned=*/false, X, Y)) { + MVT OpVT = X.getSimpleValueType(); + SDValue Sum = DAG.getNode(ISD::ADD, DL, OpVT, X, Y); + return DAG.getNode(X86ISD::PSADBW, DL, VT, Sum, + getZeroVector(OpVT, Subtarget, DAG, DL)); } if (VT.isVector()) { - SDValue X, Y; EVT BoolVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1, VT.getVectorElementCount()); From f23969c3e22f7fa4bdbc110bad8b531ec2e6ed54 Mon Sep 17 00:00:00 2001 From: woruyu <1214539920@qq.com> Date: Mon, 7 Jul 2025 18:19:23 +0800 Subject: [PATCH 2/3] test: add MatchZeroOneAllOnes testcase --- .../CodeGen/SelectionDAGPatternMatchTest.cpp | 89 +++++++++++++++++++ 1 file changed, 89 insertions(+) diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp index baee2868d2d60..5d86d2bc28593 100644 --- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp +++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp @@ -825,3 +825,92 @@ TEST_F(SelectionDAGPatternMatchTest, matchReassociatableOp) { EXPECT_FALSE(sd_match( ORS0123, m_ReassociatableOr(m_Value(), m_Value(), m_Value(), m_Value()))); } + +TEST_F(SelectionDAGPatternMatchTest, MatchZeroOneAllOnes) { + using namespace SDPatternMatch; + + SDLoc DL; + EVT VT = EVT::getIntegerVT(Context, 32); + + // Scalar constant 0 + SDValue Zero = DAG->getConstant(0, DL, VT); + EXPECT_TRUE(sd_match(Zero, DAG.get(), llvm::SDPatternMatch::m_Zero())); + EXPECT_FALSE(sd_match(Zero, DAG.get(), m_One())); + EXPECT_FALSE(sd_match(Zero, DAG.get(), m_AllOnes())); + + // Scalar constant 1 + SDValue One = DAG->getConstant(1, DL, VT); + EXPECT_FALSE(sd_match(One, DAG.get(), m_Zero())); + EXPECT_TRUE(sd_match(One, DAG.get(), m_One())); + EXPECT_FALSE(sd_match(One, DAG.get(), m_AllOnes())); + + // Scalar constant -1 + SDValue AllOnes = + DAG->getConstant(APInt::getAllOnes(VT.getSizeInBits()), DL, VT); + EXPECT_FALSE(sd_match(AllOnes, DAG.get(), m_Zero())); + EXPECT_FALSE(sd_match(AllOnes, DAG.get(), m_One())); + EXPECT_TRUE(sd_match(AllOnes, DAG.get(), m_AllOnes())); + + EVT VecF32 = EVT::getVectorVT(Context, MVT::f32, 4); + EVT VecVT = EVT::getVectorVT(Context, MVT::i32, 4); + + // m_Zero: splat vector of 0 → bitcast + { + SDValue SplatVal = DAG->getConstant(0, DL, MVT::i32); + SDValue VecSplat = DAG->getSplatBuildVector(VecVT, DL, SplatVal); + SDValue Bitcasted = DAG->getNode(ISD::BITCAST, DL, VecF32, VecSplat); + EXPECT_TRUE(sd_match(Bitcasted, DAG.get(), m_Zero())); + } + + // m_One: splat vector of 1 → bitcast + { + SDValue SplatVal = DAG->getConstant(1, DL, MVT::i32); + SDValue VecSplat = DAG->getSplatBuildVector(VecVT, DL, SplatVal); + SDValue Bitcasted = DAG->getNode(ISD::BITCAST, DL, VecF32, VecSplat); + EXPECT_TRUE(sd_match(Bitcasted, DAG.get(), m_One())); + } + + // m_AllOnes: splat vector of -1 → bitcast + { + SDValue SplatVal = DAG->getConstant(APInt::getAllOnes(32), DL, MVT::i32); + SDValue VecSplat = DAG->getSplatBuildVector(VecVT, DL, SplatVal); + SDValue Bitcasted = DAG->getNode(ISD::BITCAST, DL, VecF32, VecSplat); + EXPECT_TRUE(sd_match(Bitcasted, DAG.get(), m_AllOnes())); + } + + // splat vector with one undef → default should NOT match + SDValue Undef = DAG->getUNDEF(MVT::i32); + + { + // m_Zero: Undef + constant 0 + SDValue Zero = DAG->getConstant(0, DL, MVT::i32); + SmallVector Ops(4, Zero); + Ops[2] = Undef; + SDValue Vec = DAG->getBuildVector(VecVT, DL, Ops); + SDValue Bitcasted = DAG->getNode(ISD::BITCAST, DL, VecF32, Vec); + EXPECT_FALSE(sd_match(Bitcasted, DAG.get(), m_Zero())); + EXPECT_TRUE(sd_match(Bitcasted, DAG.get(), m_Zero(true))); + } + + { + // m_One: Undef + constant 1 + SDValue One = DAG->getConstant(1, DL, MVT::i32); + SmallVector Ops(4, One); + Ops[1] = Undef; + SDValue Vec = DAG->getBuildVector(VecVT, DL, Ops); + SDValue Bitcasted = DAG->getNode(ISD::BITCAST, DL, VecF32, Vec); + EXPECT_FALSE(sd_match(Bitcasted, DAG.get(), m_One())); + EXPECT_TRUE(sd_match(Bitcasted, DAG.get(), m_One(true))); + } + + { + // m_AllOnes: Undef + constant -1 + SDValue AllOnes = DAG->getConstant(APInt::getAllOnes(32), DL, MVT::i32); + SmallVector Ops(4, AllOnes); + Ops[0] = Undef; + SDValue Vec = DAG->getBuildVector(VecVT, DL, Ops); + SDValue Bitcasted = DAG->getNode(ISD::BITCAST, DL, VecF32, Vec); + EXPECT_FALSE(sd_match(Bitcasted, DAG.get(), m_AllOnes())); + EXPECT_TRUE(sd_match(Bitcasted, DAG.get(), m_AllOnes(true))); + } +} From 6d6c2502573011350e7ad5555236e06258a8cca3 Mon Sep 17 00:00:00 2001 From: woruyu <1214539920@qq.com> Date: Mon, 7 Jul 2025 20:11:56 +0800 Subject: [PATCH 3/3] fix: modify review for code details --- llvm/include/llvm/CodeGen/SelectionDAGNodes.h | 6 ++++++ llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 3 ++- .../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 4 ++-- llvm/lib/Target/X86/X86ISelLowering.cpp | 20 ++++++++++--------- .../CodeGen/SelectionDAGPatternMatchTest.cpp | 17 +++++++--------- 5 files changed, 28 insertions(+), 22 deletions(-) diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h index 6bfc40afeb55e..5d9937f832396 100644 --- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h +++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h @@ -1937,8 +1937,14 @@ LLVM_ABI bool isOneOrOneSplat(SDValue V, bool AllowUndefs = false); /// Does not permit build vector implicit truncation. LLVM_ABI bool isAllOnesOrAllOnesSplat(SDValue V, bool AllowUndefs = false); +/// Return true if the value is a constant 1 integer or a splatted vector of a +/// constant 1 integer (with no undefs). +/// Does not permit build vector implicit truncation. LLVM_ABI bool isOnesOrOnesSplat(SDValue N, bool AllowUndefs = false); +/// Return true if the value is a constant 0 integer or a splatted vector of a +/// constant 0 integer (with no undefs). +/// Does not permit build vector implicit truncation. LLVM_ABI bool isZeroOrZeroSplat(SDValue N, bool AllowUndefs = false); /// Return true if \p V is either a integer or FP constant. diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index f94b3a35652fc..476d2a7d42a9e 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -4281,7 +4281,8 @@ SDValue DAGCombiner::visitSUB(SDNode *N) { return V; // (A - B) - 1 -> add (xor B, -1), A - if (sd_match(N, m_Sub(m_OneUse(m_Sub(m_Value(A), m_Value(B))), m_One(true)))) + if (sd_match(N, m_Sub(m_OneUse(m_Sub(m_Value(A), m_Value(B))), + m_One(/*AllowUndefs=*/true)))) return DAG.getNode(ISD::ADD, DL, VT, A, DAG.getNOT(DL, B, VT)); // Look for: diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index d6605c3ec77dd..22a813f7a14e5 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -12570,9 +12570,9 @@ bool llvm::isAllOnesOrAllOnesSplat(SDValue N, bool AllowUndefs) { } bool llvm::isOnesOrOnesSplat(SDValue N, bool AllowUndefs) { - N = peekThroughBitcasts(N); ConstantSDNode *C = isConstOrConstSplat(N, AllowUndefs); - return C && C->getAPIntValue() == 1; + return C && APInt::isSameValue(C->getAPIntValue(), + APInt(C->getAPIntValue().getBitWidth(), 1)); } bool llvm::isZeroOrZeroSplat(SDValue N, bool AllowUndefs) { diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 7ec666d0b1658..afffe51f23a27 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -57925,20 +57925,22 @@ static SDValue combineAdd(SDNode *N, SelectionDAG &DAG, } } - SDValue X, Y; - // add(psadbw(X,0),psadbw(Y,0)) -> psadbw(add(X,Y),0) // iff X and Y won't overflow. - if (sd_match(Op0, m_c_BinOp(X86ISD::PSADBW, m_Value(X), m_Zero())) && - sd_match(Op1, m_c_BinOp(X86ISD::PSADBW, m_Value(Y), m_Zero())) && - DAG.willNotOverflowAdd(/*IsSigned=*/false, X, Y)) { - MVT OpVT = X.getSimpleValueType(); - SDValue Sum = DAG.getNode(ISD::ADD, DL, OpVT, X, Y); - return DAG.getNode(X86ISD::PSADBW, DL, VT, Sum, - getZeroVector(OpVT, Subtarget, DAG, DL)); + if (Op0.getOpcode() == X86ISD::PSADBW && Op1.getOpcode() == X86ISD::PSADBW && + ISD::isBuildVectorAllZeros(Op0.getOperand(1).getNode()) && + ISD::isBuildVectorAllZeros(Op1.getOperand(1).getNode())) { + if (DAG.willNotOverflowAdd(false, Op0.getOperand(0), Op1.getOperand(0))) { + MVT OpVT = Op0.getOperand(1).getSimpleValueType(); + SDValue Sum = + DAG.getNode(ISD::ADD, DL, OpVT, Op0.getOperand(0), Op1.getOperand(0)); + return DAG.getNode(X86ISD::PSADBW, DL, VT, Sum, + getZeroVector(OpVT, Subtarget, DAG, DL)); + } } if (VT.isVector()) { + SDValue X, Y; EVT BoolVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1, VT.getVectorElementCount()); diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp index 5d86d2bc28593..dc531e6013745 100644 --- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp +++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp @@ -867,7 +867,7 @@ TEST_F(SelectionDAGPatternMatchTest, MatchZeroOneAllOnes) { SDValue SplatVal = DAG->getConstant(1, DL, MVT::i32); SDValue VecSplat = DAG->getSplatBuildVector(VecVT, DL, SplatVal); SDValue Bitcasted = DAG->getNode(ISD::BITCAST, DL, VecF32, VecSplat); - EXPECT_TRUE(sd_match(Bitcasted, DAG.get(), m_One())); + EXPECT_FALSE(sd_match(Bitcasted, DAG.get(), m_One())); } // m_AllOnes: splat vector of -1 → bitcast @@ -887,9 +887,8 @@ TEST_F(SelectionDAGPatternMatchTest, MatchZeroOneAllOnes) { SmallVector Ops(4, Zero); Ops[2] = Undef; SDValue Vec = DAG->getBuildVector(VecVT, DL, Ops); - SDValue Bitcasted = DAG->getNode(ISD::BITCAST, DL, VecF32, Vec); - EXPECT_FALSE(sd_match(Bitcasted, DAG.get(), m_Zero())); - EXPECT_TRUE(sd_match(Bitcasted, DAG.get(), m_Zero(true))); + EXPECT_FALSE(sd_match(Vec, DAG.get(), m_Zero())); + EXPECT_TRUE(sd_match(Vec, DAG.get(), m_Zero(true))); } { @@ -898,9 +897,8 @@ TEST_F(SelectionDAGPatternMatchTest, MatchZeroOneAllOnes) { SmallVector Ops(4, One); Ops[1] = Undef; SDValue Vec = DAG->getBuildVector(VecVT, DL, Ops); - SDValue Bitcasted = DAG->getNode(ISD::BITCAST, DL, VecF32, Vec); - EXPECT_FALSE(sd_match(Bitcasted, DAG.get(), m_One())); - EXPECT_TRUE(sd_match(Bitcasted, DAG.get(), m_One(true))); + EXPECT_FALSE(sd_match(Vec, DAG.get(), m_One())); + EXPECT_TRUE(sd_match(Vec, DAG.get(), m_One(true))); } { @@ -909,8 +907,7 @@ TEST_F(SelectionDAGPatternMatchTest, MatchZeroOneAllOnes) { SmallVector Ops(4, AllOnes); Ops[0] = Undef; SDValue Vec = DAG->getBuildVector(VecVT, DL, Ops); - SDValue Bitcasted = DAG->getNode(ISD::BITCAST, DL, VecF32, Vec); - EXPECT_FALSE(sd_match(Bitcasted, DAG.get(), m_AllOnes())); - EXPECT_TRUE(sd_match(Bitcasted, DAG.get(), m_AllOnes(true))); + EXPECT_FALSE(sd_match(Vec, DAG.get(), m_AllOnes())); + EXPECT_TRUE(sd_match(Vec, DAG.get(), m_AllOnes(true))); } }