Skip to content

[llvm] [InstCombine] fold "icmp eq (X + (V - 1)) & -V, X" to "icmp eq 0, (and X, V - 1)" #152851

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from

Conversation

pskrgag
Copy link
Contributor

@pskrgag pskrgag commented Aug 9, 2025

This fold optimizes

define i1 @src(i32 %num, i32 %val) {
  %mask = add i32 %val, -1
  %neg = sub nsw i32 0, %val

  %num.biased = add i32 %num, %mask
  %_2.sroa.0.0 = and i32 %num.biased, %neg
  %_0 = icmp eq i32 %_2.sroa.0.0, %num
  ret i1 %_0
}

to

define i1 @tgt(i32 %num, i32 %val) {
  %mask = add i32 %val, -1
  %tmp = and i32 %num, %mask
  %ret = icmp eq i32 %tmp, 0
  ret i1 %ret
}

For power-of-two val.

Observed in real life for following code

pub fn is_aligned(num: usize) -> bool {
    num.next_multiple_of(1 << 12) == num
}

which verifies that num is aligned to 4096.

Alive2 proof https://alive2.llvm.org/ce/z/QisECm

@pskrgag pskrgag requested a review from nikic as a code owner August 9, 2025 12:09
@pskrgag pskrgag requested review from dtcxzyw and removed request for nikic August 9, 2025 12:09
@llvmbot llvmbot added llvm:instcombine Covers the InstCombine, InstSimplify and AggressiveInstCombine passes llvm:transforms labels Aug 9, 2025
@llvmbot
Copy link
Member

llvmbot commented Aug 9, 2025

@llvm/pr-subscribers-llvm-transforms

Author: Pavel Skripkin (pskrgag)

Changes

This fold optimizes

define i1 @<!-- -->src(i32 %num, i32 %val) {
  %mask = add i32 %val, -1
  %neg = sub nsw i32 0, %val

  %num.biased = add i32 %num, %mask
  %_2.sroa.0.0 = and i32 %num.biased, %neg
  %_0 = icmp eq i32 %_2.sroa.0.0, %num
  ret i1 %_0
}

to

define i1 @<!-- -->tgt(i32 %num, i32 %val) {
  %mask = add i32 %val, -1
  %tmp = and i32 %num, %mask
  %ret = icmp eq i32 %tmp, 0
  ret i1 %ret
}

For power-of-two val.

Observed in real life for following code

pub fn is_aligned(num: usize) -&gt; bool {
    num.next_multiple_of(1 &lt;&lt; 12) == num
}

which verifies that num is aligned to 4096.

Alive2 proof https://alive2.llvm.org/ce/z/QisECm


Full diff: https://github.com/llvm/llvm-project/pull/152851.diff

3 Files Affected:

  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp (+64)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineInternal.h (+1)
  • (modified) llvm/test/Transforms/InstCombine/icmp-add.ll (+90)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index cf94d28100488..722b03eb53f06 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -1320,6 +1320,67 @@ Instruction *InstCombinerImpl::foldICmpWithZero(ICmpInst &Cmp) {
   return nullptr;
 }
 
+// Fold icmp eq (num + (val - 1)) & -val, num
+//      to
+//      icmp eq 0, (and num, val - 1)
+// For value being power of two
+Instruction *InstCombinerImpl::foldNextMultiply(ICmpInst &Cmp) {
+  Value *Op0 = Cmp.getOperand(0), *Op1 = Cmp.getOperand(1);
+  Value *Neg, *Add, *Num, *Mask, *Value;
+  CmpInst::Predicate Pred = Cmp.getPredicate();
+  const APInt *NegConst, *MaskConst, *NumCost;
+
+  if (Pred != ICmpInst::ICMP_EQ)
+    return nullptr;
+
+  // Match num + neg
+  if (!match(Op0, m_And(m_Value(Add), m_Value(Neg))))
+    return nullptr;
+
+  // Match num & mask
+  if (!match(Add, m_Add(m_Value(Num), m_Value(Mask))))
+    return nullptr;
+
+  // Check the constant case
+  if (match(Neg, m_APInt(NegConst)) && match(Mask, m_APInt(MaskConst))) {
+    // Mask + 1 should be a power-of-two
+    if (!(*MaskConst + 1).isPowerOf2())
+      return nullptr;
+
+    // Neg = -(Mask + 1)
+    if (*NegConst != -(*MaskConst + 1))
+      return nullptr;
+  } else {
+    // Match neg = sub 0, val
+    if (!match(Neg, m_Sub(m_Zero(), m_Value(Value))))
+      return nullptr;
+
+    // mask = %val - 1, which can be represented as sub %val, 1 or add %val, -1
+    if (!match(Mask, m_Add(m_Value(Value), m_AllOnes())) &&
+        !match(Mask, m_Sub(m_Value(Value), m_One())))
+      return nullptr;
+
+    // Value should be a known power-of-two.
+    if (!isKnownToBeAPowerOfTwo(Value, false, &Cmp))
+      return nullptr;
+  }
+
+  // Guard against weird special-case where Op1 gets optimized to constant. Leave it constant
+  // fonder.
+  if (match(Op1, m_APInt(NumCost)))
+    return nullptr;
+
+  if (!match(Op1, m_Value(Num)))
+    return nullptr;
+
+  // Create new icmp eq (num & (val - 1)), 0
+  auto NewAnd = Builder.CreateAnd(Num, Mask);
+  auto Zero = llvm::Constant::getNullValue(Num->getType());
+  auto ICmp = Builder.CreateICmp(CmpInst::ICMP_EQ, NewAnd, Zero);
+
+  return replaceInstUsesWith(Cmp, ICmp);
+}
+
 /// Fold icmp Pred X, C.
 /// TODO: This code structure does not make sense. The saturating add fold
 /// should be moved to some other helper and extended as noted below (it is also
@@ -7644,6 +7705,9 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
   if (Instruction *Res = foldICmpUsingKnownBits(I))
     return Res;
 
+  if (Instruction *Res = foldNextMultiply(I))
+    return Res;
+
   // Test if the ICmpInst instruction is used exclusively by a select as
   // part of a minimum or maximum operation. If so, refrain from doing
   // any other folding. This helps out other analyses which understand
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index c67e27e5b3e7c..5f83cb1b9ae28 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -721,6 +721,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
   Instruction *foldICmpUsingKnownBits(ICmpInst &Cmp);
   Instruction *foldICmpWithDominatingICmp(ICmpInst &Cmp);
   Instruction *foldICmpWithConstant(ICmpInst &Cmp);
+  Instruction *foldNextMultiply(ICmpInst &Cmp);
   Instruction *foldICmpUsingBoolRange(ICmpInst &I);
   Instruction *foldICmpInstWithConstant(ICmpInst &Cmp);
   Instruction *foldICmpInstWithConstantNotInt(ICmpInst &Cmp);
diff --git a/llvm/test/Transforms/InstCombine/icmp-add.ll b/llvm/test/Transforms/InstCombine/icmp-add.ll
index 1a41c1f3e1045..698619ab8aad1 100644
--- a/llvm/test/Transforms/InstCombine/icmp-add.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-add.ll
@@ -3300,3 +3300,93 @@ entry:
   %cmp = icmp ult i32 %add, 253
   ret i1 %cmp
 }
+
+define i1 @val_is_aligend_sub(i32 %num, i32 %val) {
+; CHECK-LABEL: @val_is_aligend_sub(
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call range(i32 1, 33) i32 @llvm.ctpop.i32(i32 [[NUM:%.*]])
+; CHECK-NEXT:    [[POW:%.*]] = icmp eq i32 [[TMP1]], 1
+; CHECK-NEXT:    call void @llvm.assume(i1 [[POW]])
+; CHECK-NEXT:    [[NEG:%.*]] = add i32 [[NUM]], -1
+; CHECK-NEXT:    [[_2_SROA_0_0:%.*]] = and i32 [[NUM_BIASED:%.*]], [[NEG]]
+; CHECK-NEXT:    [[_0:%.*]] = icmp eq i32 [[_2_SROA_0_0]], 0
+; CHECK-NEXT:    ret i1 [[_0]]
+;
+  %1 = tail call range(i32 1, 33) i32 @llvm.ctpop.i32(i32 %val)
+  %pow = icmp eq i32 %1, 1
+  call void @llvm.assume(i1 %pow)
+
+  %mask = sub i32 %val, 1
+  %neg = sub nsw i32 0, %val
+
+  %num.biased = add i32 %num, %mask
+  %_2.sroa.0.0 = and i32 %num.biased, %neg
+  %_0 = icmp eq i32 %_2.sroa.0.0, %num
+  ret i1 %_0
+}
+
+define i1 @val_is_aligend_add(i32 %num, i32 %val) {
+; CHECK-LABEL: @val_is_aligend_add(
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call range(i32 1, 33) i32 @llvm.ctpop.i32(i32 [[NUM:%.*]])
+; CHECK-NEXT:    [[POW:%.*]] = icmp eq i32 [[TMP1]], 1
+; CHECK-NEXT:    call void @llvm.assume(i1 [[POW]])
+; CHECK-NEXT:    [[NEG:%.*]] = add i32 [[NUM]], -1
+; CHECK-NEXT:    [[_2_SROA_0_0:%.*]] = and i32 [[NUM_BIASED:%.*]], [[NEG]]
+; CHECK-NEXT:    [[_0:%.*]] = icmp eq i32 [[_2_SROA_0_0]], 0
+; CHECK-NEXT:    ret i1 [[_0]]
+;
+  %1 = tail call range(i32 1, 33) i32 @llvm.ctpop.i32(i32 %val)
+  %pow = icmp eq i32 %1, 1
+  call void @llvm.assume(i1 %pow)
+
+  %mask = add i32 %val, -1
+  %neg = sub nsw i32 0, %val
+
+  %num.biased = add i32 %num, %mask
+  %_2.sroa.0.0 = and i32 %num.biased, %neg
+  %_0 = icmp eq i32 %_2.sroa.0.0, %num
+  ret i1 %_0
+}
+
+define i1 @val_is_aligend_const_pow2(i32 %num) {
+; CHECK-LABEL: @val_is_aligend_const_pow2(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[NUM:%.*]], 4095
+; CHECK-NEXT:    [[_0:%.*]] = icmp eq i32 [[TMP1]], 0
+; CHECK-NEXT:    ret i1 [[_0]]
+;
+  %num.biased = add i32 %num, 4095
+  %_2.sroa.0.0 = and i32 %num.biased, -4096
+  %_0 = icmp eq i32 %_2.sroa.0.0, %num
+  ret i1 %_0
+}
+
+; Should not work for non-power-of-two cases
+define i1 @val_is_aligend_const_non-pow2(i32 %num) {
+; CHECK-LABEL: @val_is_aligend_const_non-pow2(
+; CHECK-NEXT:    [[NUM_BIASED:%.*]] = add i32 [[NUM:%.*]], 6
+; CHECK-NEXT:    [[_2_SROA_0_0:%.*]] = and i32 [[NUM_BIASED]], -7
+; CHECK-NEXT:    [[_0:%.*]] = icmp eq i32 [[_2_SROA_0_0]], [[NUM]]
+; CHECK-NEXT:    ret i1 [[_0]]
+;
+  %num.biased = add i32 %num, 6
+  %_2.sroa.0.0 = and i32 %num.biased, -7
+  %_0 = icmp eq i32 %_2.sroa.0.0, %num
+  ret i1 %_0
+}
+
+define i1 @val_is_aligend_non_pow(i32 %num, i32 %val) {
+; CHECK-LABEL: @val_is_aligend_non_pow(
+; CHECK-NEXT:    [[MASK:%.*]] = add i32 [[VAL:%.*]], -1
+; CHECK-NEXT:    [[NEG:%.*]] = sub nsw i32 0, [[VAL]]
+; CHECK-NEXT:    [[NUM_BIASED:%.*]] = add i32 [[NUM:%.*]], [[MASK]]
+; CHECK-NEXT:    [[_2_SROA_0_0:%.*]] = and i32 [[NUM_BIASED]], [[NEG]]
+; CHECK-NEXT:    [[_0:%.*]] = icmp eq i32 [[_2_SROA_0_0]], [[NUM]]
+; CHECK-NEXT:    ret i1 [[_0]]
+;
+  %mask = add i32 %val, -1
+  %neg = sub nsw i32 0, %val
+
+  %num.biased = add i32 %num, %mask
+  %_2.sroa.0.0 = and i32 %num.biased, %neg
+  %_0 = icmp eq i32 %_2.sroa.0.0, %num
+  ret i1 %_0
+}

@dtcxzyw dtcxzyw requested a review from nikic August 9, 2025 18:23
@pskrgag
Copy link
Contributor Author

pskrgag commented Aug 9, 2025

@dtcxzyw Big thanks you for review! Addressed your comments

Copy link

github-actions bot commented Aug 9, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

if (*NegConst != -(*MaskConst + 1))
return nullptr;
} else {
// Match neg = sub 0, val
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I doubt the value of handling variable cases. Let's see the real-world usefulness after fixing the multi-use problem.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure! I am also curios to know.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This path looks unprofitable. See also https://github.com/dtcxzyw/llvm-opt-benchmark/pull/2657/files.
All the power-of-2 values are constants.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make sense. Dropped the general case.

@pskrgag
Copy link
Contributor Author

pskrgag commented Aug 11, 2025

CI looks unrelated

  ERROR: test_modulelist_deadlock (TestStatusline.TestStatusline.test_modulelist_deadlock)
     Regression test for a deadlock that occurs when the status line is enabled before connecting

@@ -1320,6 +1320,40 @@ Instruction *InstCombinerImpl::foldICmpWithZero(ICmpInst &Cmp) {
return nullptr;
}

// Fold icmp eq (num + (val - 1)) & -val, num
// to
// icmp eq 0, (and num, val - 1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// icmp eq 0, (and num, val - 1)
// icmp eq (and num, val - 1), 0


if (!match(&Cmp, m_c_ICmp(Pred, m_Value(Num),
m_OneUse(m_c_And(m_OneUse(m_c_Add(m_Deferred(Num),
m_Value(Mask))),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Directly match m_LowBitMask here? and then use m_SpecificInt(~*Mask) for m_Value(Neg)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The second part didn't work because of crash on Mask de-reference. Changed m_Value(Neg) to m_APInt(Neg) and kept the Neg != ~Mask part

if (!match(Neg, m_APInt(NegConst)) || !match(Mask, m_LowBitMask(MaskConst)))
return nullptr;

// Neg = -(Mask + 1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// Neg = -(Mask + 1)
// Neg = ~Mask

To match implementation.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dropped comment entirely

%_2.sroa.0.0 = and i32 %num.biased, -7
%_0 = icmp eq i32 %_2.sroa.0.0, %num
ret i1 %_0
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add more negative tests, e.g. for non-equality predicate and mismatch between the constants.

Also add a test with ne instead of eq.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added suggested tests + added commute tests, which I missed from previous iteration

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
llvm:instcombine Covers the InstCombine, InstSimplify and AggressiveInstCombine passes llvm:transforms
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants