Skip to content

Commit 60e8915

Browse files
committed
[InstCombine] Add folds for (add/sub/disjoint_or/icmp C, (ctpop (not x)))
`(ctpop (not x))` <-> `(sub nuw nsw BitWidth(x), (ctpop x))`. The `sub` expression can sometimes be constant folded depending on the use case of `(ctpop (not x))`. This patch adds folds for the following cases: `(add/sub/disjoint_or C, (ctpop (not x)))` -> `(add/sub/disjoint_or C', (ctpop x))` `(cmp pred C, (ctpop (not x)))` -> `(cmp swapped_pred C', (ctpop x))` Where `C'` depends on how we constant fold `C` with `BitWidth(x)` for the given opcode. Proofs: https://alive2.llvm.org/ce/z/qUgfF3 Closes #77859
1 parent 73863a4 commit 60e8915

File tree

6 files changed

+125
-31
lines changed

6 files changed

+125
-31
lines changed

llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1683,6 +1683,9 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
16831683
}
16841684
}
16851685

1686+
if (Instruction *R = tryFoldInstWithCtpopWithNot(&I))
1687+
return R;
1688+
16861689
// TODO(jingyue): Consider willNotOverflowSignedAdd and
16871690
// willNotOverflowUnsignedAdd to reduce the number of invocations of
16881691
// computeKnownBits.
@@ -2445,6 +2448,9 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
24452448
}
24462449
}
24472450

2451+
if (Instruction *R = tryFoldInstWithCtpopWithNot(&I))
2452+
return R;
2453+
24482454
if (Instruction *R = foldSubOfMinMax(I, Builder))
24492455
return R;
24502456

llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3398,6 +3398,9 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
33983398
if (Instruction *R = foldBinOpShiftWithShift(I))
33993399
return R;
34003400

3401+
if (Instruction *R = tryFoldInstWithCtpopWithNot(&I))
3402+
return R;
3403+
34013404
Value *X, *Y;
34023405
const APInt *CV;
34033406
if (match(&I, m_c_Or(m_OneUse(m_Xor(m_Value(X), m_APInt(CV))), m_Value(Y))) &&

llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1323,6 +1323,9 @@ Instruction *InstCombinerImpl::foldICmpWithConstant(ICmpInst &Cmp) {
13231323
return replaceInstUsesWith(Cmp, NewPhi);
13241324
}
13251325

1326+
if (Instruction *R = tryFoldInstWithCtpopWithNot(&Cmp))
1327+
return R;
1328+
13261329
return nullptr;
13271330
}
13281331

llvm/lib/Transforms/InstCombine/InstCombineInternal.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -505,6 +505,10 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
505505
Value *SimplifySelectsFeedingBinaryOp(BinaryOperator &I, Value *LHS,
506506
Value *RHS);
507507

508+
// If `I` has operand `(ctpop (not x))`, fold `I` with `(sub nuw nsw
509+
// BitWidth(x), (ctpop x))`.
510+
Instruction *tryFoldInstWithCtpopWithNot(Instruction *I);
511+
508512
// (Binop1 (Binop2 (logic_shift X, C), C1), (logic_shift Y, C))
509513
// -> (logic_shift (Binop1 (Binop2 X, inv_logic_shift(C1, C)), Y), C)
510514
// (Binop1 (Binop2 (logic_shift X, Amt), Mask), (logic_shift Y, Amt))

llvm/lib/Transforms/InstCombine/InstructionCombining.cpp

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -740,6 +740,93 @@ static Value *tryFactorization(BinaryOperator &I, const SimplifyQuery &SQ,
740740
return RetVal;
741741
}
742742

743+
// If `I` has one Const operand and the other matches `(ctpop (not x))`,
744+
// replace `(ctpop (not x))` with `(sub nuw nsw BitWidth(x), (ctpop x))`.
745+
// This is only useful if the new subtract can fold so we only handle the
746+
// following cases:
747+
// 1) (add/sub/disjoint_or C, (ctpop (not x)))
748+
// -> (add/sub/disjoint_or C', (ctpop x))
749+
// 2) (cmp pred C, (ctpop (not x)))
750+
// -> (cmp swapped_pred C', (ctpop x))
751+
Instruction *InstCombinerImpl::tryFoldInstWithCtpopWithNot(Instruction *I) {
752+
unsigned Opc = I->getOpcode();
753+
unsigned ConstIdx = 1;
754+
switch (Opc) {
755+
default:
756+
return nullptr;
757+
// (ctpop (not x)) <-> (sub nuw nsw BitWidth(x), (ctpop x))
758+
// We can fold the BitWidth(x) with add/sub/icmp as long as the other operand
759+
// is constant.
760+
case Instruction::Sub:
761+
ConstIdx = 0;
762+
break;
763+
case Instruction::ICmp:
764+
// Signed predicates aren't correct in some edge cases like for i2 types. As
765+
// well, since (ctpop x) is known [0, BitWidth(x)], almost all signed
766+
// comparisons against it are simplified to unsigned.
767+
if (cast<ICmpInst>(I)->isSigned())
768+
return nullptr;
769+
break;
770+
case Instruction::Or:
771+
if (!match(I, m_DisjointOr(m_Value(), m_Value())))
772+
return nullptr;
773+
[[fallthrough]];
774+
case Instruction::Add:
775+
break;
776+
}
777+
778+
Value *Op;
779+
// Find a one-use ctpop on the non-constant side; `Op` is its argument.
780+
if (!match(I->getOperand(1 - ConstIdx),
781+
m_OneUse(m_Intrinsic<Intrinsic::ctpop>(m_Value(Op)))))
782+
return nullptr;
783+
784+
Constant *C;
785+
// Check other operand is ImmConstant.
786+
if (!match(I->getOperand(ConstIdx), m_ImmConstant(C)))
787+
return nullptr;
788+
789+
Type *Ty = Op->getType();
790+
Constant *BitWidthC = ConstantInt::get(Ty, Ty->getScalarSizeInBits());
791+
// Non-equality icmp needs C u<= BitWidth(x) so `BitWidth(x) - C` can't wrap.
792+
// When that fails, it generally means the icmp will simplify to true/false.
793+
if (Opc == Instruction::ICmp && !cast<ICmpInst>(I)->isEquality() &&
794+
!ConstantExpr::getICmp(ICmpInst::ICMP_UGT, C, BitWidthC)->isZeroValue())
795+
return nullptr;
796+
797+
// Check we can invert `(not x)` for free.
798+
bool Consumes = false;
799+
if (!isFreeToInvert(Op, Op->hasOneUse(), Consumes) || !Consumes)
800+
return nullptr;
801+
Value *NotOp = getFreelyInverted(Op, Op->hasOneUse(), &Builder);
802+
assert(NotOp != nullptr &&
803+
"Desync between isFreeToInvert and getFreelyInverted");
804+
805+
Value *CtpopOfNotOp = Builder.CreateIntrinsic(Ty, Intrinsic::ctpop, NotOp);
806+
807+
Value *R = nullptr;
808+
809+
// Do the transformation here to avoid potentially introducing an infinite
810+
// loop.
811+
switch (Opc) {
812+
case Instruction::Sub:
813+
R = Builder.CreateAdd(CtpopOfNotOp, ConstantExpr::getSub(C, BitWidthC));
814+
break;
815+
case Instruction::Or:
816+
case Instruction::Add:
817+
R = Builder.CreateSub(ConstantExpr::getAdd(C, BitWidthC), CtpopOfNotOp);
818+
break;
819+
case Instruction::ICmp:
820+
R = Builder.CreateICmp(cast<ICmpInst>(I)->getSwappedPredicate(),
821+
CtpopOfNotOp, ConstantExpr::getSub(BitWidthC, C));
822+
break;
823+
default:
824+
llvm_unreachable("Unhandled Opcode");
825+
}
826+
assert(R != nullptr);
827+
return replaceInstUsesWith(*I, R);
828+
}
829+
743830
// (Binop1 (Binop2 (logic_shift X, C), C1), (logic_shift Y, C))
744831
// IFF
745832
// 1) the logic_shifts match

llvm/test/Transforms/InstCombine/fold-ctpop-of-not.ll

Lines changed: 22 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,8 @@ declare <2 x i8> @llvm.ctpop.v2i8(<2 x i8>)
88

99
define i8 @fold_sub_c_ctpop(i8 %x) {
1010
; CHECK-LABEL: @fold_sub_c_ctpop(
11-
; CHECK-NEXT: [[NX:%.*]] = xor i8 [[X:%.*]], -1
12-
; CHECK-NEXT: [[CNT:%.*]] = call i8 @llvm.ctpop.i8(i8 [[NX]]), !range [[RNG0:![0-9]+]]
13-
; CHECK-NEXT: [[R:%.*]] = sub nuw nsw i8 12, [[CNT]]
11+
; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.ctpop.i8(i8 [[X:%.*]]), !range [[RNG0:![0-9]+]]
12+
; CHECK-NEXT: [[R:%.*]] = add nuw nsw i8 [[TMP1]], 4
1413
; CHECK-NEXT: ret i8 [[R]]
1514
;
1615
%nx = xor i8 %x, -1
@@ -34,9 +33,8 @@ define i8 @fold_sub_var_ctpop_fail(i8 %x, i8 %y) {
3433

3534
define <2 x i8> @fold_sub_ctpop_c(<2 x i8> %x) {
3635
; CHECK-LABEL: @fold_sub_ctpop_c(
37-
; CHECK-NEXT: [[NX:%.*]] = xor <2 x i8> [[X:%.*]], <i8 -1, i8 -1>
38-
; CHECK-NEXT: [[CNT:%.*]] = call <2 x i8> @llvm.ctpop.v2i8(<2 x i8> [[NX]]), !range [[RNG0]]
39-
; CHECK-NEXT: [[R:%.*]] = add nuw nsw <2 x i8> [[CNT]], <i8 -63, i8 -64>
36+
; CHECK-NEXT: [[TMP1:%.*]] = call <2 x i8> @llvm.ctpop.v2i8(<2 x i8> [[X:%.*]]), !range [[RNG0]]
37+
; CHECK-NEXT: [[R:%.*]] = sub nuw nsw <2 x i8> <i8 -55, i8 -56>, [[TMP1]]
4038
; CHECK-NEXT: ret <2 x i8> [[R]]
4139
;
4240
%nx = xor <2 x i8> %x, <i8 -1, i8 -1>
@@ -47,9 +45,8 @@ define <2 x i8> @fold_sub_ctpop_c(<2 x i8> %x) {
4745

4846
define i8 @fold_add_ctpop_c(i8 %x) {
4947
; CHECK-LABEL: @fold_add_ctpop_c(
50-
; CHECK-NEXT: [[NX:%.*]] = xor i8 [[X:%.*]], -1
51-
; CHECK-NEXT: [[CNT:%.*]] = call i8 @llvm.ctpop.i8(i8 [[NX]]), !range [[RNG0]]
52-
; CHECK-NEXT: [[R:%.*]] = add nuw nsw i8 [[CNT]], 63
48+
; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.ctpop.i8(i8 [[X:%.*]]), !range [[RNG0]]
49+
; CHECK-NEXT: [[R:%.*]] = sub nuw nsw i8 71, [[TMP1]]
5350
; CHECK-NEXT: ret i8 [[R]]
5451
;
5552
%nx = xor i8 %x, -1
@@ -60,9 +57,8 @@ define i8 @fold_add_ctpop_c(i8 %x) {
6057

6158
define i8 @fold_distjoint_or_ctpop_c(i8 %x) {
6259
; CHECK-LABEL: @fold_distjoint_or_ctpop_c(
63-
; CHECK-NEXT: [[NX:%.*]] = xor i8 [[X:%.*]], -1
64-
; CHECK-NEXT: [[CNT:%.*]] = call i8 @llvm.ctpop.i8(i8 [[NX]]), !range [[RNG0]]
65-
; CHECK-NEXT: [[R:%.*]] = or disjoint i8 [[CNT]], 64
60+
; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.ctpop.i8(i8 [[X:%.*]]), !range [[RNG0]]
61+
; CHECK-NEXT: [[R:%.*]] = sub nuw nsw i8 72, [[TMP1]]
6662
; CHECK-NEXT: ret i8 [[R]]
6763
;
6864
%nx = xor i8 %x, -1
@@ -109,9 +105,8 @@ define i1 @fold_icmp_sgt_ctpop_c_i2_fail(i2 %x, i2 %C) {
109105

110106
define i1 @fold_cmp_eq_ctpop_c(i8 %x) {
111107
; CHECK-LABEL: @fold_cmp_eq_ctpop_c(
112-
; CHECK-NEXT: [[NX:%.*]] = xor i8 [[X:%.*]], -1
113-
; CHECK-NEXT: [[CNT:%.*]] = call i8 @llvm.ctpop.i8(i8 [[NX]]), !range [[RNG0]]
114-
; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[CNT]], 2
108+
; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.ctpop.i8(i8 [[X:%.*]]), !range [[RNG0]]
109+
; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[TMP1]], 6
115110
; CHECK-NEXT: ret i1 [[R]]
116111
;
117112
%nx = xor i8 %x, -1
@@ -137,9 +132,8 @@ define i1 @fold_cmp_eq_ctpop_c_multiuse_fail(i8 %x) {
137132

138133
define <2 x i1> @fold_cmp_ne_ctpop_c(<2 x i8> %x) {
139134
; CHECK-LABEL: @fold_cmp_ne_ctpop_c(
140-
; CHECK-NEXT: [[NX:%.*]] = xor <2 x i8> [[X:%.*]], <i8 -1, i8 -1>
141-
; CHECK-NEXT: [[CNT:%.*]] = call <2 x i8> @llvm.ctpop.v2i8(<2 x i8> [[NX]]), !range [[RNG0]]
142-
; CHECK-NEXT: [[R:%.*]] = icmp ne <2 x i8> [[CNT]], <i8 44, i8 3>
135+
; CHECK-NEXT: [[TMP1:%.*]] = call <2 x i8> @llvm.ctpop.v2i8(<2 x i8> [[X:%.*]]), !range [[RNG0]]
136+
; CHECK-NEXT: [[R:%.*]] = icmp ne <2 x i8> [[TMP1]], <i8 -36, i8 5>
143137
; CHECK-NEXT: ret <2 x i1> [[R]]
144138
;
145139
%nx = xor <2 x i8> %x, <i8 -1, i8 -1>
@@ -163,11 +157,10 @@ define <2 x i1> @fold_cmp_ne_ctpop_var_fail(<2 x i8> %x, <2 x i8> %y) {
163157

164158
define i1 @fold_cmp_ult_ctpop_c(i8 %x, i8 %y, i1 %cond) {
165159
; CHECK-LABEL: @fold_cmp_ult_ctpop_c(
166-
; CHECK-NEXT: [[NX:%.*]] = xor i8 [[X:%.*]], -1
167-
; CHECK-NEXT: [[NY:%.*]] = add i8 [[Y:%.*]], 15
168-
; CHECK-NEXT: [[N:%.*]] = select i1 [[COND:%.*]], i8 [[NX]], i8 [[NY]]
169-
; CHECK-NEXT: [[CNT:%.*]] = call i8 @llvm.ctpop.i8(i8 [[N]]), !range [[RNG0]]
170-
; CHECK-NEXT: [[R:%.*]] = icmp ult i8 [[CNT]], 5
160+
; CHECK-NEXT: [[TMP1:%.*]] = sub i8 -16, [[Y:%.*]]
161+
; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[COND:%.*]], i8 [[X:%.*]], i8 [[TMP1]]
162+
; CHECK-NEXT: [[TMP3:%.*]] = call i8 @llvm.ctpop.i8(i8 [[TMP2]]), !range [[RNG0]]
163+
; CHECK-NEXT: [[R:%.*]] = icmp ugt i8 [[TMP3]], 3
171164
; CHECK-NEXT: ret i1 [[R]]
172165
;
173166
%nx = xor i8 %x, -1
@@ -180,11 +173,10 @@ define i1 @fold_cmp_ult_ctpop_c(i8 %x, i8 %y, i1 %cond) {
180173

181174
define i1 @fold_cmp_sle_ctpop_c(i8 %x, i8 %y, i1 %cond) {
182175
; CHECK-LABEL: @fold_cmp_sle_ctpop_c(
183-
; CHECK-NEXT: [[NX:%.*]] = xor i8 [[X:%.*]], -1
184-
; CHECK-NEXT: [[NY:%.*]] = add i8 [[Y:%.*]], 15
185-
; CHECK-NEXT: [[N:%.*]] = select i1 [[COND:%.*]], i8 [[NX]], i8 [[NY]]
186-
; CHECK-NEXT: [[CNT:%.*]] = call i8 @llvm.ctpop.i8(i8 [[N]]), !range [[RNG0]]
187-
; CHECK-NEXT: [[R:%.*]] = icmp ult i8 [[CNT]], 4
176+
; CHECK-NEXT: [[TMP1:%.*]] = sub i8 -16, [[Y:%.*]]
177+
; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[COND:%.*]], i8 [[X:%.*]], i8 [[TMP1]]
178+
; CHECK-NEXT: [[TMP3:%.*]] = call i8 @llvm.ctpop.i8(i8 [[TMP2]]), !range [[RNG0]]
179+
; CHECK-NEXT: [[R:%.*]] = icmp ugt i8 [[TMP3]], 4
188180
; CHECK-NEXT: ret i1 [[R]]
189181
;
190182
%nx = xor i8 %x, -1
@@ -210,9 +202,8 @@ define i1 @fold_cmp_ult_ctpop_c_no_not_inst_save_fail(i8 %x) {
210202

211203
define <2 x i1> @fold_cmp_ugt_ctpop_c(<2 x i8> %x) {
212204
; CHECK-LABEL: @fold_cmp_ugt_ctpop_c(
213-
; CHECK-NEXT: [[NX:%.*]] = xor <2 x i8> [[X:%.*]], <i8 -1, i8 -1>
214-
; CHECK-NEXT: [[CNT:%.*]] = call <2 x i8> @llvm.ctpop.v2i8(<2 x i8> [[NX]]), !range [[RNG0]]
215-
; CHECK-NEXT: [[R:%.*]] = icmp ugt <2 x i8> [[CNT]], <i8 8, i8 6>
205+
; CHECK-NEXT: [[TMP1:%.*]] = call <2 x i8> @llvm.ctpop.v2i8(<2 x i8> [[X:%.*]]), !range [[RNG0]]
206+
; CHECK-NEXT: [[R:%.*]] = icmp ult <2 x i8> [[TMP1]], <i8 0, i8 2>
216207
; CHECK-NEXT: ret <2 x i1> [[R]]
217208
;
218209
%nx = xor <2 x i8> %x, <i8 -1, i8 -1>

0 commit comments

Comments
 (0)