Skip to content

Commit 45d8955

Browse files
committed
[InstCombine] Fold integer unpack/repack patterns through ZExt (llvm#153583)
This patch explicitly enables the InstCombiner to fold integer unpack/repack patterns such as:

```llvm
define i64 @src_combine(i32 %lower, i32 %upper) {
  %base = zext i32 %lower to i64
  %u.0 = and i32 %upper, u0xff
  %z.0 = zext i32 %u.0 to i64
  %s.0 = shl i64 %z.0, 32
  %o.0 = or i64 %base, %s.0
  %r.1 = lshr i32 %upper, 8
  %u.1 = and i32 %r.1, u0xff
  %z.1 = zext i32 %u.1 to i64
  %s.1 = shl i64 %z.1, 40
  %o.1 = or i64 %o.0, %s.1
  %r.2 = lshr i32 %upper, 16
  %u.2 = and i32 %r.2, u0xff
  %z.2 = zext i32 %u.2 to i64
  %s.2 = shl i64 %z.2, 48
  %o.2 = or i64 %o.1, %s.2
  %r.3 = lshr i32 %upper, 24
  %u.3 = and i32 %r.3, u0xff
  %z.3 = zext i32 %u.3 to i64
  %s.3 = shl i64 %z.3, 56
  %o.3 = or i64 %o.2, %s.3
  ret i64 %o.3
}

; =>

define i64 @tgt_combine(i32 %lower, i32 %upper) {
  %base = zext i32 %lower to i64
  %upper.zext = zext i32 %upper to i64
  %s.0 = shl nuw i64 %upper.zext, 32
  %o.3 = or disjoint i64 %s.0, %base
  ret i64 %o.3
}
```

Alive2 proofs: [YAy7ny](https://alive2.llvm.org/ce/z/YAy7ny)
1 parent 128dcb1 commit 45d8955

File tree

4 files changed

+361
-2
lines changed

4 files changed

+361
-2
lines changed

llvm/include/llvm/IR/IRBuilder.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1537,10 +1537,14 @@ class IRBuilderBase {
15371537
return Accum;
15381538
}
15391539

1540-
Value *CreateOr(Value *LHS, Value *RHS, const Twine &Name = "") {
1540+
Value *CreateOr(Value *LHS, Value *RHS, const Twine &Name = "",
1541+
bool IsDisjoint = false) {
15411542
if (auto *V = Folder.FoldBinOp(Instruction::Or, LHS, RHS))
15421543
return V;
1543-
return Insert(BinaryOperator::CreateOr(LHS, RHS), Name);
1544+
return Insert(
1545+
IsDisjoint ? BinaryOperator::CreateDisjoint(Instruction::Or, LHS, RHS)
1546+
: BinaryOperator::CreateOr(LHS, RHS),
1547+
Name);
15441548
}
15451549

15461550
Value *CreateOr(Value *LHS, const APInt &RHS, const Twine &Name = "") {

llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3546,6 +3546,109 @@ static Value *foldOrOfInversions(BinaryOperator &I,
35463546
return nullptr;
35473547
}
35483548

3549+
/// Match \p V as "lshr -> mask -> zext -> shl".
3550+
///
3551+
/// \p Int is the underlying integer being extracted from.
3552+
/// \p Mask is a bitmask identifying which bits of the integer are being
3553+
/// extracted. \p Offset identifies which bit of the result \p V corresponds to
3554+
/// the least significant bit of \p Int
3555+
static bool matchZExtedSubInteger(Value *V, Value *&Int, APInt &Mask,
3556+
uint64_t &Offset, bool &IsShlNUW,
3557+
bool &IsShlNSW) {
3558+
Value *ShlOp0;
3559+
uint64_t ShlAmt = 0;
3560+
if (!match(V, m_OneUse(m_Shl(m_Value(ShlOp0), m_ConstantInt(ShlAmt)))))
3561+
return false;
3562+
3563+
IsShlNUW = cast<BinaryOperator>(V)->hasNoUnsignedWrap();
3564+
IsShlNSW = cast<BinaryOperator>(V)->hasNoSignedWrap();
3565+
3566+
Value *ZExtOp0;
3567+
if (!match(ShlOp0, m_OneUse(m_ZExt(m_Value(ZExtOp0)))))
3568+
return false;
3569+
3570+
Value *MaskedOp0;
3571+
const APInt *ShiftedMaskConst = nullptr;
3572+
if (!match(ZExtOp0, m_CombineOr(m_OneUse(m_And(m_Value(MaskedOp0),
3573+
m_APInt(ShiftedMaskConst))),
3574+
m_Value(MaskedOp0))))
3575+
return false;
3576+
3577+
uint64_t LShrAmt = 0;
3578+
if (!match(MaskedOp0,
3579+
m_CombineOr(m_OneUse(m_LShr(m_Value(Int), m_ConstantInt(LShrAmt))),
3580+
m_Value(Int))))
3581+
return false;
3582+
3583+
if (LShrAmt > ShlAmt)
3584+
return false;
3585+
Offset = ShlAmt - LShrAmt;
3586+
3587+
Mask = ShiftedMaskConst ? ShiftedMaskConst->shl(LShrAmt)
3588+
: APInt::getBitsSetFrom(
3589+
Int->getType()->getScalarSizeInBits(), LShrAmt);
3590+
3591+
return true;
3592+
}
3593+
3594+
/// Try to fold the join of two scalar integers whose bits are unpacked and
3595+
/// zexted from the same source integer.
3596+
static Value *foldIntegerRepackThroughZExt(Value *Lhs, Value *Rhs,
3597+
InstCombiner::BuilderTy &Builder) {
3598+
3599+
Value *LhsInt, *RhsInt;
3600+
APInt LhsMask, RhsMask;
3601+
uint64_t LhsOffset, RhsOffset;
3602+
bool IsLhsShlNUW, IsLhsShlNSW, IsRhsShlNUW, IsRhsShlNSW;
3603+
if (!matchZExtedSubInteger(Lhs, LhsInt, LhsMask, LhsOffset, IsLhsShlNUW,
3604+
IsLhsShlNSW))
3605+
return nullptr;
3606+
if (!matchZExtedSubInteger(Rhs, RhsInt, RhsMask, RhsOffset, IsRhsShlNUW,
3607+
IsRhsShlNSW))
3608+
return nullptr;
3609+
if (LhsInt != RhsInt || LhsOffset != RhsOffset)
3610+
return nullptr;
3611+
3612+
APInt Mask = LhsMask | RhsMask;
3613+
3614+
Type *DestTy = Lhs->getType();
3615+
Value *Res = Builder.CreateShl(
3616+
Builder.CreateZExt(
3617+
Builder.CreateAnd(LhsInt, Mask, LhsInt->getName() + ".mask"), DestTy,
3618+
LhsInt->getName() + ".zext"),
3619+
ConstantInt::get(DestTy, LhsOffset), "", IsLhsShlNUW && IsRhsShlNUW,
3620+
IsLhsShlNSW && IsRhsShlNSW);
3621+
Res->takeName(Lhs);
3622+
return Res;
3623+
}
3624+
3625+
Value *InstCombinerImpl::foldDisjointOr(Value *LHS, Value *RHS) {
3626+
if (Value *Res = foldIntegerRepackThroughZExt(LHS, RHS, Builder))
3627+
return Res;
3628+
3629+
return nullptr;
3630+
}
3631+
3632+
Value *InstCombinerImpl::reassociateDisjointOr(Value *LHS, Value *RHS) {
3633+
3634+
Value *X, *Y;
3635+
if (match(RHS, m_OneUse(m_DisjointOr(m_Value(X), m_Value(Y))))) {
3636+
if (Value *Res = foldDisjointOr(LHS, X))
3637+
return Builder.CreateOr(Res, Y, "", /*IsDisjoint=*/true);
3638+
if (Value *Res = foldDisjointOr(LHS, Y))
3639+
return Builder.CreateOr(Res, X, "", /*IsDisjoint=*/true);
3640+
}
3641+
3642+
if (match(LHS, m_OneUse(m_DisjointOr(m_Value(X), m_Value(Y))))) {
3643+
if (Value *Res = foldDisjointOr(X, RHS))
3644+
return Builder.CreateOr(Res, Y, "", /*IsDisjoint=*/true);
3645+
if (Value *Res = foldDisjointOr(Y, RHS))
3646+
return Builder.CreateOr(Res, X, "", /*IsDisjoint=*/true);
3647+
}
3648+
3649+
return nullptr;
3650+
}
3651+
35493652
// FIXME: We use commutative matchers (m_c_*) for some, but not all, matches
35503653
// here. We should standardize that construct where it is needed or choose some
35513654
// other way to ensure that commutated variants of patterns are not missed.
@@ -3627,6 +3730,12 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
36273730
foldAddLikeCommutative(I.getOperand(1), I.getOperand(0),
36283731
/*NSW=*/true, /*NUW=*/true))
36293732
return R;
3733+
3734+
if (Value *Res = foldDisjointOr(I.getOperand(0), I.getOperand(1)))
3735+
return replaceInstUsesWith(I, Res);
3736+
3737+
if (Value *Res = reassociateDisjointOr(I.getOperand(0), I.getOperand(1)))
3738+
return replaceInstUsesWith(I, Res);
36303739
}
36313740

36323741
Value *X, *Y;

llvm/lib/Transforms/InstCombine/InstCombineInternal.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,10 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
432432
Value *reassociateBooleanAndOr(Value *LHS, Value *X, Value *Y, Instruction &I,
433433
bool IsAnd, bool RHSIsLogical);
434434

435+
Value *foldDisjointOr(Value *LHS, Value *RHS);
436+
437+
Value *reassociateDisjointOr(Value *LHS, Value *RHS);
438+
435439
Instruction *
436440
canonicalizeConditionalNegationViaMathToSelect(BinaryOperator &i);
437441

Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,242 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
2+
; RUN: opt -passes=instcombine %s -S | FileCheck %s
3+
4+
declare void @use.i32(i32)
5+
declare void @use.i64(i64)
6+
7+
define i64 @full_shl(i32 %x) {
8+
; CHECK-LABEL: define i64 @full_shl(
9+
; CHECK-SAME: i32 [[X:%.*]]) {
10+
; CHECK-NEXT: [[X_ZEXT:%.*]] = zext i32 [[X]] to i64
11+
; CHECK-NEXT: [[LO_SHL:%.*]] = shl nuw nsw i64 [[X_ZEXT]], 24
12+
; CHECK-NEXT: ret i64 [[LO_SHL]]
13+
;
14+
%lo = and i32 %x, u0xffff
15+
%lo.zext = zext nneg i32 %lo to i64
16+
%lo.shl = shl nuw nsw i64 %lo.zext, 24
17+
18+
%hi = lshr i32 %x, 16
19+
%hi.zext = zext nneg i32 %hi to i64
20+
%hi.shl = shl nuw nsw i64 %hi.zext, 40
21+
22+
%res = or disjoint i64 %lo.shl, %hi.shl
23+
ret i64 %res
24+
}
25+
26+
define <2 x i64> @full_shl_vec(<2 x i32> %v) {
27+
; CHECK-LABEL: define <2 x i64> @full_shl_vec(
28+
; CHECK-SAME: <2 x i32> [[V:%.*]]) {
29+
; CHECK-NEXT: [[V_ZEXT:%.*]] = zext <2 x i32> [[V]] to <2 x i64>
30+
; CHECK-NEXT: [[LO_SHL:%.*]] = shl nuw nsw <2 x i64> [[V_ZEXT]], splat (i64 24)
31+
; CHECK-NEXT: ret <2 x i64> [[LO_SHL]]
32+
;
33+
%lo = and <2 x i32> %v, splat(i32 u0xffff)
34+
%lo.zext = zext nneg <2 x i32> %lo to <2 x i64>
35+
%lo.shl = shl nuw nsw <2 x i64> %lo.zext, splat(i64 24)
36+
37+
%hi = lshr <2 x i32> %v, splat(i32 16)
38+
%hi.zext = zext nneg <2 x i32> %hi to <2 x i64>
39+
%hi.shl = shl nuw nsw <2 x i64> %hi.zext, splat(i64 40)
40+
41+
%res = or disjoint <2 x i64> %lo.shl, %hi.shl
42+
ret <2 x i64> %res
43+
}
44+
45+
; u0xaabbccdd = -1430532899
46+
define i64 @partial_shl(i32 %x) {
47+
; CHECK-LABEL: define i64 @partial_shl(
48+
; CHECK-SAME: i32 [[X:%.*]]) {
49+
; CHECK-NEXT: [[X_MASK:%.*]] = and i32 [[X]], -1430532899
50+
; CHECK-NEXT: [[X_ZEXT:%.*]] = zext i32 [[X_MASK]] to i64
51+
; CHECK-NEXT: [[LO_SHL:%.*]] = shl nuw nsw i64 [[X_ZEXT]], 24
52+
; CHECK-NEXT: ret i64 [[LO_SHL]]
53+
;
54+
%lo = and i32 %x, u0xccdd
55+
%lo.zext = zext nneg i32 %lo to i64
56+
%lo.shl = shl nuw nsw i64 %lo.zext, 24
57+
58+
%hi = lshr i32 %x, 16
59+
%hi.mask = and i32 %hi, u0xaabb
60+
%hi.zext = zext nneg i32 %hi.mask to i64
61+
%hi.shl = shl nuw nsw i64 %hi.zext, 40
62+
63+
%res = or disjoint i64 %lo.shl, %hi.shl
64+
ret i64 %res
65+
}
66+
67+
define i64 @shl_multi_use_shl(i32 %x) {
68+
; CHECK-LABEL: define i64 @shl_multi_use_shl(
69+
; CHECK-SAME: i32 [[X:%.*]]) {
70+
; CHECK-NEXT: [[TMP1:%.*]] = shl i32 [[X]], 24
71+
; CHECK-NEXT: [[LO_SHL:%.*]] = zext i32 [[TMP1]] to i64
72+
; CHECK-NEXT: call void @use.i64(i64 [[LO_SHL]])
73+
; CHECK-NEXT: [[HI:%.*]] = lshr i32 [[X]], 16
74+
; CHECK-NEXT: [[HI_ZEXT:%.*]] = zext nneg i32 [[HI]] to i64
75+
; CHECK-NEXT: [[HI_SHL:%.*]] = shl nuw nsw i64 [[HI_ZEXT]], 40
76+
; CHECK-NEXT: [[RES:%.*]] = or disjoint i64 [[HI_SHL]], [[LO_SHL]]
77+
; CHECK-NEXT: ret i64 [[RES]]
78+
;
79+
%lo = and i32 %x, u0x00ff
80+
%lo.zext = zext nneg i32 %lo to i64
81+
%lo.shl = shl nuw nsw i64 %lo.zext, 24
82+
call void @use.i64(i64 %lo.shl)
83+
84+
%hi = lshr i32 %x, 16
85+
%hi.zext = zext nneg i32 %hi to i64
86+
%hi.shl = shl nuw nsw i64 %hi.zext, 40
87+
88+
%res = or disjoint i64 %lo.shl, %hi.shl
89+
ret i64 %res
90+
}
91+
92+
define i64 @shl_multi_use_zext(i32 %x) {
93+
; CHECK-LABEL: define i64 @shl_multi_use_zext(
94+
; CHECK-SAME: i32 [[X:%.*]]) {
95+
; CHECK-NEXT: [[LO:%.*]] = and i32 [[X]], 255
96+
; CHECK-NEXT: [[LO_ZEXT:%.*]] = zext nneg i32 [[LO]] to i64
97+
; CHECK-NEXT: call void @use.i64(i64 [[LO_ZEXT]])
98+
; CHECK-NEXT: [[LO_SHL:%.*]] = shl nuw nsw i64 [[LO_ZEXT]], 24
99+
; CHECK-NEXT: [[HI:%.*]] = lshr i32 [[X]], 16
100+
; CHECK-NEXT: [[HI_ZEXT:%.*]] = zext nneg i32 [[HI]] to i64
101+
; CHECK-NEXT: [[HI_SHL:%.*]] = shl nuw nsw i64 [[HI_ZEXT]], 40
102+
; CHECK-NEXT: [[RES:%.*]] = or disjoint i64 [[LO_SHL]], [[HI_SHL]]
103+
; CHECK-NEXT: ret i64 [[RES]]
104+
;
105+
%lo = and i32 %x, u0x00ff
106+
%lo.zext = zext nneg i32 %lo to i64
107+
call void @use.i64(i64 %lo.zext)
108+
%lo.shl = shl nuw nsw i64 %lo.zext, 24
109+
110+
%hi = lshr i32 %x, 16
111+
%hi.zext = zext nneg i32 %hi to i64
112+
%hi.shl = shl nuw nsw i64 %hi.zext, 40
113+
114+
%res = or disjoint i64 %lo.shl, %hi.shl
115+
ret i64 %res
116+
}
117+
118+
define i64 @shl_multi_use_lshr(i32 %x) {
119+
; CHECK-LABEL: define i64 @shl_multi_use_lshr(
120+
; CHECK-SAME: i32 [[X:%.*]]) {
121+
; CHECK-NEXT: [[TMP1:%.*]] = shl i32 [[X]], 24
122+
; CHECK-NEXT: [[LO_SHL:%.*]] = zext i32 [[TMP1]] to i64
123+
; CHECK-NEXT: [[HI:%.*]] = lshr i32 [[X]], 16
124+
; CHECK-NEXT: call void @use.i32(i32 [[HI]])
125+
; CHECK-NEXT: [[HI_ZEXT:%.*]] = zext nneg i32 [[HI]] to i64
126+
; CHECK-NEXT: [[HI_SHL:%.*]] = shl nuw nsw i64 [[HI_ZEXT]], 40
127+
; CHECK-NEXT: [[RES:%.*]] = or disjoint i64 [[HI_SHL]], [[LO_SHL]]
128+
; CHECK-NEXT: ret i64 [[RES]]
129+
;
130+
%lo = and i32 %x, u0x00ff
131+
%lo.zext = zext nneg i32 %lo to i64
132+
%lo.shl = shl nuw nsw i64 %lo.zext, 24
133+
134+
%hi = lshr i32 %x, 16
135+
call void @use.i32(i32 %hi)
136+
%hi.zext = zext nneg i32 %hi to i64
137+
%hi.shl = shl nuw nsw i64 %hi.zext, 40
138+
139+
%res = or disjoint i64 %lo.shl, %hi.shl
140+
ret i64 %res
141+
}
142+
143+
define i64 @shl_non_disjoint(i32 %x) {
144+
; CHECK-LABEL: define i64 @shl_non_disjoint(
145+
; CHECK-SAME: i32 [[X:%.*]]) {
146+
; CHECK-NEXT: [[LO:%.*]] = and i32 [[X]], 16711680
147+
; CHECK-NEXT: [[LO_ZEXT:%.*]] = zext nneg i32 [[LO]] to i64
148+
; CHECK-NEXT: [[LO_SHL:%.*]] = shl nuw nsw i64 [[LO_ZEXT]], 24
149+
; CHECK-NEXT: [[HI:%.*]] = lshr i32 [[X]], 16
150+
; CHECK-NEXT: call void @use.i32(i32 [[HI]])
151+
; CHECK-NEXT: [[HI_ZEXT:%.*]] = zext nneg i32 [[HI]] to i64
152+
; CHECK-NEXT: [[HI_SHL:%.*]] = shl nuw nsw i64 [[HI_ZEXT]], 40
153+
; CHECK-NEXT: [[RES:%.*]] = or i64 [[LO_SHL]], [[HI_SHL]]
154+
; CHECK-NEXT: ret i64 [[RES]]
155+
;
156+
%lo = and i32 %x, u0x00ff0000
157+
%lo.zext = zext nneg i32 %lo to i64
158+
%lo.shl = shl nuw nsw i64 %lo.zext, 24
159+
160+
%hi = lshr i32 %x, 16
161+
call void @use.i32(i32 %hi)
162+
%hi.zext = zext nneg i32 %hi to i64
163+
%hi.shl = shl nuw nsw i64 %hi.zext, 40
164+
165+
%res = or i64 %lo.shl, %hi.shl
166+
ret i64 %res
167+
}
168+
169+
define i64 @combine(i32 %lower, i32 %upper) {
170+
; CHECK-LABEL: define i64 @combine(
171+
; CHECK-SAME: i32 [[LOWER:%.*]], i32 [[UPPER:%.*]]) {
172+
; CHECK-NEXT: [[BASE:%.*]] = zext i32 [[LOWER]] to i64
173+
; CHECK-NEXT: [[UPPER_ZEXT:%.*]] = zext i32 [[UPPER]] to i64
174+
; CHECK-NEXT: [[S_0:%.*]] = shl nuw i64 [[UPPER_ZEXT]], 32
175+
; CHECK-NEXT: [[O_3:%.*]] = or disjoint i64 [[S_0]], [[BASE]]
176+
; CHECK-NEXT: ret i64 [[O_3]]
177+
;
178+
%base = zext i32 %lower to i64
179+
180+
%u.0 = and i32 %upper, u0xff
181+
%z.0 = zext i32 %u.0 to i64
182+
%s.0 = shl i64 %z.0, 32
183+
%o.0 = or i64 %base, %s.0
184+
185+
%r.1 = lshr i32 %upper, 8
186+
%u.1 = and i32 %r.1, u0xff
187+
%z.1 = zext i32 %u.1 to i64
188+
%s.1 = shl i64 %z.1, 40
189+
%o.1 = or i64 %o.0, %s.1
190+
191+
%r.2 = lshr i32 %upper, 16
192+
%u.2 = and i32 %r.2, u0xff
193+
%z.2 = zext i32 %u.2 to i64
194+
%s.2 = shl i64 %z.2, 48
195+
%o.2 = or i64 %o.1, %s.2
196+
197+
%r.3 = lshr i32 %upper, 24
198+
%u.3 = and i32 %r.3, u0xff
199+
%z.3 = zext i32 %u.3 to i64
200+
%s.3 = shl i64 %z.3, 56
201+
%o.3 = or i64 %o.2, %s.3
202+
203+
ret i64 %o.3
204+
}
205+
206+
define i64 @combine_2(i32 %lower, i32 %upper) {
207+
; CHECK-LABEL: define i64 @combine_2(
208+
; CHECK-SAME: i32 [[LOWER:%.*]], i32 [[UPPER:%.*]]) {
209+
; CHECK-NEXT: [[BASE:%.*]] = zext i32 [[LOWER]] to i64
210+
; CHECK-NEXT: [[S_03:%.*]] = zext i32 [[UPPER]] to i64
211+
; CHECK-NEXT: [[O:%.*]] = shl nuw i64 [[S_03]], 32
212+
; CHECK-NEXT: [[RES:%.*]] = or disjoint i64 [[O]], [[BASE]]
213+
; CHECK-NEXT: ret i64 [[RES]]
214+
;
215+
%base = zext i32 %lower to i64
216+
217+
%u.0 = and i32 %upper, u0xff
218+
%z.0 = zext i32 %u.0 to i64
219+
%s.0 = shl i64 %z.0, 32
220+
221+
%r.1 = lshr i32 %upper, 8
222+
%u.1 = and i32 %r.1, u0xff
223+
%z.1 = zext i32 %u.1 to i64
224+
%s.1 = shl i64 %z.1, 40
225+
%o.1 = or i64 %s.0, %s.1
226+
227+
%r.2 = lshr i32 %upper, 16
228+
%u.2 = and i32 %r.2, u0xff
229+
%z.2 = zext i32 %u.2 to i64
230+
%s.2 = shl i64 %z.2, 48
231+
232+
%r.3 = lshr i32 %upper, 24
233+
%u.3 = and i32 %r.3, u0xff
234+
%z.3 = zext i32 %u.3 to i64
235+
%s.3 = shl i64 %z.3, 56
236+
%o.3 = or i64 %s.2, %s.3
237+
238+
%o = or i64 %o.1, %o.3
239+
%res = or i64 %o, %base
240+
241+
ret i64 %res
242+
}

0 commit comments

Comments
 (0)