
Conversation

@goldsteinn goldsteinn requested a review from nikic as a code owner January 10, 2025 22:47
@llvmbot llvmbot added the llvm:instcombine and llvm:transforms labels Jan 10, 2025
@goldsteinn goldsteinn requested a review from dtcxzyw January 10, 2025 22:48
@llvmbot
Member

llvmbot commented Jan 10, 2025

@llvm/pr-subscribers-llvm-transforms

Author: None (goldsteinn)

Changes

https://alive2.llvm.org/ce/z/zGwUBp

This came up when trying to implement:
(ctlz Pow2) -> (sub/xor BW - 1, Log2(Pow2))

https://github.com/dtcxzyw/llvm-opt-benchmark/pull/1944/files#diff-b13a246c0599d6b2255fae8707bb4e36bf28221ff67f7ad54d169f7bd5ba4e22R13996
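
For readers skimming the diff below: the canonicalization this patch adds, as a minimal before/after sketch (the mask constant 255 is illustrative; the fold applies to any value known to be a low-bit mask or zero). The nuw flag guarantees %x <= 255, so the subtraction cannot borrow across the mask bits and is equivalent to flipping them:

```llvm
define i32 @before(i32 %x) {
  ; nuw: 255 - %x does not wrap, i.e. %x <= 255
  %r = sub nuw i32 255, %x
  ret i32 %r
}

define i32 @after(i32 %x) {
  %r = xor i32 %x, 255
  ret i32 %r
}
```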


Full diff: https://github.com/llvm/llvm-project/pull/122542.diff

9 Files Affected:

  • (modified) llvm/include/llvm/Transforms/InstCombine/InstCombiner.h (+5)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp (+6)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp (+10-10)
  • (modified) llvm/test/Transforms/InstCombine/sub-xor.ll (+1-1)
  • (modified) llvm/test/Transforms/InstCombine/sub.ll (+3-3)
  • (modified) llvm/test/Transforms/InstCombine/vec_demanded_elts-inseltpoison.ll (+1-1)
  • (modified) llvm/test/Transforms/InstCombine/vec_demanded_elts.ll (+1-1)
  • (modified) llvm/test/Transforms/PhaseOrdering/X86/pr88239.ll (+7-6)
  • (modified) llvm/test/Transforms/PhaseOrdering/scev-custom-dl.ll (+4-4)
diff --git a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
index fa6b60cba15aaf..1dae2c82dae2e4 100644
--- a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
+++ b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
@@ -236,6 +236,11 @@ class LLVM_LIBRARY_VISIBILITY InstCombiner {
     return isFreeToInvert(V, WillInvertAllUses, Unused);
   }
 
+  /// If `Not` is true, returns true if V is a negative power of 2 or zero.
+  /// If `Not` is false, returns true if V is a Mask or zero.
+  bool isMaskOrZero(const Value *V, bool Not, const SimplifyQuery &Q,
+                    unsigned Depth = 0);
+
   /// Given i1 V, can every user of V be freely adapted if V is changed to !V ?
   /// InstCombine's freelyInvertAllUsersOf() must be kept in sync with this fn.
   /// NOTE: for Instructions only!
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 73876d00e73a7c..d91437e8aec9e1 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -2780,6 +2780,12 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
   if (Instruction *Res = foldBinOpOfSelectAndCastOfSelectCondition(I))
     return Res;
 
+  // Canonicalize (sub nuw Mask, X) -> (xor Mask, X)
+  if (I.hasNoUnsignedWrap() &&
+      isMaskOrZero(Op0, /*Not=*/false,
+                   getSimplifyQuery().getWithInstruction(&I)))
+    return BinaryOperator::CreateXor(Op0, Op1);
+
   return TryToNarrowDeduceFlags();
 }
 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 2e457257599493..bc3db5b065df42 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -4326,8 +4326,8 @@ Instruction *InstCombinerImpl::foldSelectICmp(CmpPredicate Pred, SelectInst *SI,
 }
 
 // Returns whether V is a Mask ((X + 1) & X == 0) or ~Mask (-Pow2OrZero)
-static bool isMaskOrZero(const Value *V, bool Not, const SimplifyQuery &Q,
-                         unsigned Depth = 0) {
+bool InstCombiner::isMaskOrZero(const Value *V, bool Not,
+                                const SimplifyQuery &Q, unsigned Depth) {
   if (Not ? match(V, m_NegatedPower2OrZero()) : match(V, m_LowBitMaskOrZero()))
     return true;
   if (V->getType()->getScalarSizeInBits() == 1)
@@ -4381,14 +4381,14 @@ static bool isMaskOrZero(const Value *V, bool Not, const SimplifyQuery &Q,
   case Instruction::Add:
     // Pow2 - 1 is a Mask.
     if (!Not && match(I->getOperand(1), m_AllOnes()))
-      return isKnownToBeAPowerOfTwo(I->getOperand(0), Q.DL, /*OrZero*/ true,
-                                    Depth, Q.AC, Q.CxtI, Q.DT);
+      return ::isKnownToBeAPowerOfTwo(I->getOperand(0), Q.DL, /*OrZero*/ true,
+                                      Depth, Q.AC, Q.CxtI, Q.DT);
     break;
   case Instruction::Sub:
     // -Pow2 is a ~Mask.
     if (Not && match(I->getOperand(0), m_Zero()))
-      return isKnownToBeAPowerOfTwo(I->getOperand(1), Q.DL, /*OrZero*/ true,
-                                    Depth, Q.AC, Q.CxtI, Q.DT);
+      return ::isKnownToBeAPowerOfTwo(I->getOperand(1), Q.DL, /*OrZero*/ true,
+                                      Depth, Q.AC, Q.CxtI, Q.DT);
     break;
   case Instruction::Call: {
     if (auto *II = dyn_cast<IntrinsicInst>(I)) {
@@ -4497,13 +4497,13 @@ static Value *foldICmpWithLowBitMaskedVal(CmpPredicate Pred, Value *Op0,
     if (match(Op0, m_c_And(m_Specific(Op1), m_Value(M)))) {
       X = Op1;
       // Look for: x & Mask pred x
-      if (isMaskOrZero(M, /*Not=*/false, Q)) {
+      if (IC.isMaskOrZero(M, /*Not=*/false, Q)) {
         return !ICmpInst::isSigned(Pred) ||
                (match(M, m_NonNegative()) || isKnownNonNegative(M, Q));
       }
 
       // Look for: x & ~Mask pred ~Mask
-      if (isMaskOrZero(X, /*Not=*/true, Q)) {
+      if (IC.isMaskOrZero(X, /*Not=*/true, Q)) {
         return !ICmpInst::isSigned(Pred) || isKnownNonZero(X, Q);
       }
       return false;
@@ -4513,7 +4513,7 @@ static Value *foldICmpWithLowBitMaskedVal(CmpPredicate Pred, Value *Op0,
 
       auto Check = [&]() {
         // Look for: ~x | Mask == -1
-        if (isMaskOrZero(M, /*Not=*/false, Q)) {
+        if (IC.isMaskOrZero(M, /*Not=*/false, Q)) {
           if (Value *NotX =
                   IC.getFreelyInverted(X, X->hasOneUse(), &IC.Builder)) {
             X = NotX;
@@ -4531,7 +4531,7 @@ static Value *foldICmpWithLowBitMaskedVal(CmpPredicate Pred, Value *Op0,
         match(Op0, m_OneUse(m_And(m_Value(X), m_Value(M))))) {
       auto Check = [&]() {
         // Look for: x & ~Mask == 0
-        if (isMaskOrZero(M, /*Not=*/true, Q)) {
+        if (IC.isMaskOrZero(M, /*Not=*/true, Q)) {
           if (Value *NotM =
                   IC.getFreelyInverted(M, M->hasOneUse(), &IC.Builder)) {
             M = NotM;
diff --git a/llvm/test/Transforms/InstCombine/sub-xor.ll b/llvm/test/Transforms/InstCombine/sub-xor.ll
index a4135e0b514532..180de6d2f88282 100644
--- a/llvm/test/Transforms/InstCombine/sub-xor.ll
+++ b/llvm/test/Transforms/InstCombine/sub-xor.ll
@@ -166,7 +166,7 @@ define i32 @xor_dominating_cond(i32 %x) {
 ; CHECK-NEXT:    [[COND:%.*]] = icmp ult i32 [[X:%.*]], 256
 ; CHECK-NEXT:    br i1 [[COND]], label [[IF_THEN:%.*]], label [[IF_END:%.*]]
 ; CHECK:       if.then:
-; CHECK-NEXT:    [[A:%.*]] = sub nuw nsw i32 255, [[X]]
+; CHECK-NEXT:    [[A:%.*]] = xor i32 [[X]], 255
 ; CHECK-NEXT:    ret i32 [[A]]
 ; CHECK:       if.end:
 ; CHECK-NEXT:    ret i32 [[X]]
diff --git a/llvm/test/Transforms/InstCombine/sub.ll b/llvm/test/Transforms/InstCombine/sub.ll
index e89419d1f3838a..ebda52a1354953 100644
--- a/llvm/test/Transforms/InstCombine/sub.ll
+++ b/llvm/test/Transforms/InstCombine/sub.ll
@@ -2204,9 +2204,9 @@ define i8 @shrink_sub_from_constant_lowbits2(i8 %x) {
 
 define <2 x i8> @shrink_sub_from_constant_lowbits3(<2 x i8> %x) {
 ; CHECK-LABEL: @shrink_sub_from_constant_lowbits3(
-; CHECK-NEXT:    [[X0000:%.*]] = shl <2 x i8> [[X:%.*]], splat (i8 4)
-; CHECK-NEXT:    [[SUB:%.*]] = sub nuw <2 x i8> splat (i8 24), [[X0000]]
-; CHECK-NEXT:    [[R:%.*]] = lshr exact <2 x i8> [[SUB]], splat (i8 3)
+; CHECK-NEXT:    [[TMP1:%.*]] = shl <2 x i8> [[X:%.*]], splat (i8 1)
+; CHECK-NEXT:    [[SUB:%.*]] = and <2 x i8> [[TMP1]], splat (i8 30)
+; CHECK-NEXT:    [[R:%.*]] = xor <2 x i8> [[SUB]], splat (i8 3)
 ; CHECK-NEXT:    ret <2 x i8> [[R]]
 ;
   %x0000 = shl <2 x i8> %x, <i8 4, i8 4>     ; 4 low bits are known zero
diff --git a/llvm/test/Transforms/InstCombine/vec_demanded_elts-inseltpoison.ll b/llvm/test/Transforms/InstCombine/vec_demanded_elts-inseltpoison.ll
index a240dfe7d271a9..8d0e340b77a10c 100644
--- a/llvm/test/Transforms/InstCombine/vec_demanded_elts-inseltpoison.ll
+++ b/llvm/test/Transforms/InstCombine/vec_demanded_elts-inseltpoison.ll
@@ -215,7 +215,7 @@ define <3 x i8> @shuf_add(<3 x i8> %x) {
 
 define <3 x i8> @shuf_sub(<3 x i8> %x) {
 ; CHECK-LABEL: @shuf_sub(
-; CHECK-NEXT:    [[BO:%.*]] = sub nuw <3 x i8> <i8 1, i8 poison, i8 3>, [[X:%.*]]
+; CHECK-NEXT:    [[BO:%.*]] = xor <3 x i8> [[X:%.*]], <i8 1, i8 poison, i8 3>
 ; CHECK-NEXT:    [[R:%.*]] = shufflevector <3 x i8> [[BO]], <3 x i8> poison, <3 x i32> <i32 0, i32 poison, i32 2>
 ; CHECK-NEXT:    ret <3 x i8> [[R]]
 ;
diff --git a/llvm/test/Transforms/InstCombine/vec_demanded_elts.ll b/llvm/test/Transforms/InstCombine/vec_demanded_elts.ll
index ee7ef9955e643f..f9f7017f1bd70d 100644
--- a/llvm/test/Transforms/InstCombine/vec_demanded_elts.ll
+++ b/llvm/test/Transforms/InstCombine/vec_demanded_elts.ll
@@ -218,7 +218,7 @@ define <3 x i8> @shuf_add(<3 x i8> %x) {
 
 define <3 x i8> @shuf_sub(<3 x i8> %x) {
 ; CHECK-LABEL: @shuf_sub(
-; CHECK-NEXT:    [[BO:%.*]] = sub nuw <3 x i8> <i8 1, i8 poison, i8 3>, [[X:%.*]]
+; CHECK-NEXT:    [[BO:%.*]] = xor <3 x i8> [[X:%.*]], <i8 1, i8 poison, i8 3>
 ; CHECK-NEXT:    [[R:%.*]] = shufflevector <3 x i8> [[BO]], <3 x i8> poison, <3 x i32> <i32 0, i32 poison, i32 2>
 ; CHECK-NEXT:    ret <3 x i8> [[R]]
 ;
diff --git a/llvm/test/Transforms/PhaseOrdering/X86/pr88239.ll b/llvm/test/Transforms/PhaseOrdering/X86/pr88239.ll
index b3625094f07ea1..a6c66ff40d8ac6 100644
--- a/llvm/test/Transforms/PhaseOrdering/X86/pr88239.ll
+++ b/llvm/test/Transforms/PhaseOrdering/X86/pr88239.ll
@@ -8,18 +8,19 @@ define void @foo(ptr noalias noundef %0, ptr noalias noundef %1) optsize {
 ; CHECK-LABEL: define void @foo(
 ; CHECK-SAME: ptr noalias nocapture noundef readonly [[TMP0:%.*]], ptr noalias nocapture noundef writeonly [[TMP1:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
 ; CHECK-NEXT:  vector.ph:
-; CHECK-NEXT:    [[INVARIANT_GEP:%.*]] = getelementptr i8, ptr [[TMP0]], i64 -28
 ; CHECK-NEXT:    br label [[TMP4:%.*]]
 ; CHECK:       vector.body:
 ; CHECK-NEXT:    [[INDVARS_IV:%.*]] = phi i64 [ 0, [[TMP2:%.*]] ], [ [[INDVARS_IV_NEXT:%.*]], [[TMP4]] ]
-; CHECK-NEXT:    [[TMP3:%.*]] = sub nuw nsw i64 255, [[INDVARS_IV]]
-; CHECK-NEXT:    [[GEP:%.*]] = getelementptr i32, ptr [[INVARIANT_GEP]], i64 [[TMP3]]
-; CHECK-NEXT:    [[WIDE_MASKED_GATHER:%.*]] = load <8 x i32>, ptr [[GEP]], align 4
+; CHECK-NEXT:    [[VEC_IND:%.*]] = phi <8 x i64> [ <i64 0, i64 1, i64 2, i64 3, i64 4, i64 5, i64 6, i64 7>, [[TMP2]] ], [ [[VEC_IND_NEXT:%.*]], [[TMP4]] ]
+; CHECK-NEXT:    [[TMP6:%.*]] = and <8 x i64> [[VEC_IND]], splat (i64 4294967295)
+; CHECK-NEXT:    [[TMP3:%.*]] = xor <8 x i64> [[TMP6]], splat (i64 255)
+; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr inbounds nuw i32, ptr [[TMP0]], <8 x i64> [[TMP3]]
+; CHECK-NEXT:    [[WIDE_MASKED_GATHER:%.*]] = tail call <8 x i32> @llvm.masked.gather.v8i32.v8p0(<8 x ptr> [[TMP7]], i32 4, <8 x i1> splat (i1 true), <8 x i32> poison)
 ; CHECK-NEXT:    [[TMP5:%.*]] = add nsw <8 x i32> [[WIDE_MASKED_GATHER]], splat (i32 5)
-; CHECK-NEXT:    [[TMP6:%.*]] = shufflevector <8 x i32> [[TMP5]], <8 x i32> poison, <8 x i32> <i32 7, i32 6, i32 5, i32 4, i32 3, i32 2, i32 1, i32 0>
 ; CHECK-NEXT:    [[TMP10:%.*]] = getelementptr inbounds nuw i32, ptr [[TMP1]], i64 [[INDVARS_IV]]
-; CHECK-NEXT:    store <8 x i32> [[TMP6]], ptr [[TMP10]], align 4
+; CHECK-NEXT:    store <8 x i32> [[TMP5]], ptr [[TMP10]], align 4
 ; CHECK-NEXT:    [[INDVARS_IV_NEXT]] = add nuw i64 [[INDVARS_IV]], 8
+; CHECK-NEXT:    [[VEC_IND_NEXT]] = add <8 x i64> [[VEC_IND]], splat (i64 8)
 ; CHECK-NEXT:    [[EXITCOND_NOT:%.*]] = icmp eq i64 [[INDVARS_IV_NEXT]], 256
 ; CHECK-NEXT:    br i1 [[EXITCOND_NOT]], label [[MIDDLE_BLOCK:%.*]], label [[TMP4]], !llvm.loop [[LOOP0:![0-9]+]]
 ; CHECK:       middle.block:
diff --git a/llvm/test/Transforms/PhaseOrdering/scev-custom-dl.ll b/llvm/test/Transforms/PhaseOrdering/scev-custom-dl.ll
index 4a19940b6f874f..181669165220e7 100644
--- a/llvm/test/Transforms/PhaseOrdering/scev-custom-dl.ll
+++ b/llvm/test/Transforms/PhaseOrdering/scev-custom-dl.ll
@@ -141,14 +141,14 @@ define i32 @test_loop_idiom_recogize(i32 %x, i32 %y, ptr %lam, ptr %alp) nounwin
 ; CHECK-NEXT:  Classifying expressions for: @test_loop_idiom_recogize
 ; CHECK-NEXT:    %indvar = phi i32 [ 0, %bb1.thread ], [ %indvar.next, %bb1 ]
 ; CHECK-NEXT:    --> {0,+,1}<nuw><nsw><%bb1> U: [0,256) S: [0,256) Exits: 255 LoopDispositions: { %bb1: Computable }
-; CHECK-NEXT:    %i.0.reg2mem.0 = sub nuw nsw i32 255, %indvar
-; CHECK-NEXT:    --> {255,+,-1}<nsw><%bb1> U: [0,256) S: [0,256) Exits: 0 LoopDispositions: { %bb1: Computable }
+; CHECK-NEXT:    %i.0.reg2mem.0 = xor i32 %indvar, 255
+; CHECK-NEXT:    --> %i.0.reg2mem.0 U: [0,-2147483648) S: [0,-2147483648) Exits: 0 LoopDispositions: { %bb1: Variant }
 ; CHECK-NEXT:    %0 = getelementptr i32, ptr %alp, i32 %i.0.reg2mem.0
-; CHECK-NEXT:    --> {(1020 + %alp),+,-4}<nw><%bb1> U: full-set S: full-set Exits: %alp LoopDispositions: { %bb1: Computable }
+; CHECK-NEXT:    --> ((4 * %i.0.reg2mem.0) + %alp) U: full-set S: full-set Exits: %alp LoopDispositions: { %bb1: Variant }
 ; CHECK-NEXT:    %1 = load i32, ptr %0, align 4
 ; CHECK-NEXT:    --> %1 U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %bb1: Variant }
 ; CHECK-NEXT:    %2 = getelementptr i32, ptr %lam, i32 %i.0.reg2mem.0
-; CHECK-NEXT:    --> {(1020 + %lam),+,-4}<nw><%bb1> U: full-set S: full-set Exits: %lam LoopDispositions: { %bb1: Computable }
+; CHECK-NEXT:    --> ((4 * %i.0.reg2mem.0) + %lam) U: full-set S: full-set Exits: %lam LoopDispositions: { %bb1: Variant }
 ; CHECK-NEXT:    %indvar.next = add nuw nsw i32 %indvar, 1
 ; CHECK-NEXT:    --> {1,+,1}<nuw><nsw><%bb1> U: [1,257) S: [1,257) Exits: 256 LoopDispositions: { %bb1: Computable }
 ; CHECK-NEXT:    %tmp10 = mul i32 %x, 255

@goldsteinn
Contributor Author

Note, it might be preferable to just handle the simple case in InstSimplify.

@goldsteinn
Contributor Author

The inability to fold the xor version w/ add leads to many regressions. OTOH there are definitely some positive cases.

I think we want to either just make this a simplification or we need some IR flag indicating the equivilence.

@goldsteinn
Contributor Author

The inability to fold the xor version w/ add leads to many regressions. OTOH there are definitely some positive cases.

I think we want to either just make this a simplification, or we need some IR flag indicating the equivalence.

Or just update folds using xor to also accept sub nuw. Let me know which of these sounds most reasonable.

@dtcxzyw
Member

dtcxzyw commented Jan 11, 2025

Or just update folds using xor to also accept sub nuw. Let me know which of these sounds most reasonable.

We can start by adding a matcher for both Xor and sub nuw MaskC, X.

we need some IR flag indicating the equivalence.

The downside is that we have to fight with more flag preservation bugs :(

@goldsteinn
Contributor Author

Or just update folds using xor to also accept sub nuw. Let me know which of these sounds most reasonable.

We can start by adding a matcher for both Xor and sub nuw MaskC, X.

Yeah, I guess something like m_XorLike.

we need some IR flag indicating the equivalence.

The downside is that we have to fight with more flag preservation bugs :(

Yeah, I don't imagine this is a profitable enough case to make it worth it.
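
To make the m_XorLike idea concrete (the name is only a proposal from this thread, not an existing PatternMatch helper), such a matcher would accept both of these forms as the same operation:

```llvm
; Equivalent whenever the nuw guarantee holds, i.e. %x has no bits
; outside the mask constant:
define i8 @as_xor(i8 %x) {
  %a = xor i8 %x, 15
  ret i8 %a
}

define i8 @as_sub(i8 %x) {
  %b = sub nuw i8 15, %x
  ret i8 %b
}
```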

@nikic
Contributor

nikic commented Jan 11, 2025

Why is your ctlz fold generating xor in the first place, instead of sub? Your xor implementation is a miscompile for non-pow2 bit widths; sub is always correct.
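
A worked example of this point (widths chosen for illustration): at i32, BW - 1 = 31 is a low-bit mask, so sub and xor agree for every possible Log2 result; at a non-power-of-2 width like i9, BW - 1 = 8 is not a mask and the xor form gives the wrong answer:

```llvm
define i32 @ctlz_as_sub_i32(i32 %log2) {
  ; 31 = 0b11111 covers any Log2 result in [0, 31], so this
  ; equals: xor i32 %log2, 31
  %c = sub nuw i32 31, %log2
  ret i32 %c
}

define i9 @ctlz_as_sub_i9(i9 %log2) {
  ; 8 = 0b000001000 is not a mask; for %log2 = 1 (ctlz of 2),
  ; sub gives 8 - 1 = 7 (correct) but xor would give 8 ^ 1 = 9
  %c = sub nuw i9 8, %log2
  ret i9 %c
}
```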

