Skip to content

Conversation

@RKSimon
Copy link
Collaborator

@RKSimon RKSimon commented May 20, 2025

No description provided.

…)), Y) -> (sub Y, (sext (vXi1 X)))" matching.
@llvmbot
Copy link
Member

llvmbot commented May 20, 2025

@llvm/pr-subscribers-backend-x86

Author: Simon Pilgrim (RKSimon)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/140731.diff

1 Files Affected:

  • (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+9-14)
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index b5d930ed4f7c3..1ee49a4f8a97a 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -57886,6 +57886,7 @@ static SDValue pushAddIntoCmovOfConsts(SDNode *N, const SDLoc &DL,
 static SDValue combineAdd(SDNode *N, SelectionDAG &DAG,
                           TargetLowering::DAGCombinerInfo &DCI,
                           const X86Subtarget &Subtarget) {
+  using namespace SDPatternMatch;
   EVT VT = N->getValueType(0);
   SDValue Op0 = N->getOperand(0);
   SDValue Op1 = N->getOperand(1);
@@ -57925,26 +57926,20 @@ static SDValue combineAdd(SDNode *N, SelectionDAG &DAG,
   // generic DAG combine without a legal type check, but adding this there
   // caused regressions.
   if (VT.isVector()) {
-    const TargetLowering &TLI = DAG.getTargetLoweringInfo();
-    if (Op0.getOpcode() == ISD::ZERO_EXTEND &&
-        Op0.getOperand(0).getValueType().getVectorElementType() == MVT::i1 &&
-        TLI.isTypeLegal(Op0.getOperand(0).getValueType())) {
-      SDValue SExt = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Op0.getOperand(0));
-      return DAG.getNode(ISD::SUB, DL, VT, Op1, SExt);
-    }
-
-    if (Op1.getOpcode() == ISD::ZERO_EXTEND &&
-        Op1.getOperand(0).getValueType().getVectorElementType() == MVT::i1 &&
-        TLI.isTypeLegal(Op1.getOperand(0).getValueType())) {
-      SDValue SExt = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Op1.getOperand(0));
-      return DAG.getNode(ISD::SUB, DL, VT, Op0, SExt);
+    SDValue X, Y;
+    EVT BoolVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1,
+                                  VT.getVectorElementCount());
+    if (DAG.getTargetLoweringInfo().isTypeLegal(BoolVT) &&
+        sd_match(N, m_Add(m_ZExt(m_AllOf(m_SpecificVT(BoolVT), m_Value(X))),
+                          m_Value(Y)))) {
+      SDValue SExt = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, X);
+      return DAG.getNode(ISD::SUB, DL, VT, Y, SExt);
     }
   }
 
   // Peephole for 512-bit VPDPBSSD on non-VLX targets.
   // TODO: Should this be part of matchPMADDWD/matchPMADDWD_2?
   if (Subtarget.hasVNNI() && Subtarget.useAVX512Regs() && VT == MVT::v16i32) {
-    using namespace SDPatternMatch;
     SDValue Accum, Lo0, Lo1, Hi0, Hi1;
     if (sd_match(N, m_Add(m_Value(Accum),
                           m_Node(ISD::CONCAT_VECTORS,

@RKSimon RKSimon merged commit 621a5a9 into llvm:main May 20, 2025
13 checks passed
@RKSimon RKSimon deleted the x86-combine-add-sdmatch branch May 20, 2025 15:00
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants