Skip to content

Commit a5f0824

Browse files
committed
[InstCombine] Dropping redundant masking before left-shift [0/5] (PR42563)
Summary: If we have some pattern that leaves only some low bits set, and then performs a left-shift of those bits, then if none of the bits that are left after the final shift are modified by the mask, we can omit the mask. There are many variants to this pattern: a. `(x & ((1 << MaskShAmt) - 1)) << ShiftShAmt` All these patterns can be simplified to just: `x << ShiftShAmt` iff: a. `(MaskShAmt+ShiftShAmt) u>= bitwidth(x)` alive proof: a: https://rise4fun.com/Alive/wi9 Indeed, not all of these patterns are canonical. But since this fold will only produce a single instruction, I'm really interested in handling even uncanonical patterns, since I have this general kind of pattern in hotpaths, and it is not totally outlandish for bit-twiddling code. For now let's start with patterns where both shift amounts are variable, with a trivial constant "offset" between them, since I believe this is both the simplest to handle and the most common. But again, there are likely other variants where we could use ValueTracking/ConstantRange to handle more cases. https://bugs.llvm.org/show_bug.cgi?id=42563 Reviewers: spatel, nikic, huihuiz, xbolva00 Reviewed By: xbolva00 Subscribers: efriedma, hiraditya, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D64512 llvm-svn: 366535
1 parent 3628d94 commit a5f0824

File tree

2 files changed

+61
-11
lines changed

2 files changed

+61
-11
lines changed

llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,53 @@ reassociateShiftAmtsOfTwoSameDirectionShifts(BinaryOperator *Sh0,
6565
return NewShift;
6666
}
6767

68+
// If we have some pattern that leaves only some low bits set, and then performs
69+
// left-shift of those bits, if none of the bits that are left after the final
70+
// shift are modified by the mask, we can omit the mask.
71+
//
72+
// There are many variants to this pattern:
73+
// a) (x & ((1 << MaskShAmt) - 1)) << ShiftShAmt
74+
// All these patterns can be simplified to just:
75+
// x << ShiftShAmt
76+
// iff:
77+
// a) (MaskShAmt+ShiftShAmt) u>= bitwidth(x)
//
// OuterShift is the outer 'shl' (this is asserted below). SQ is the query
// used to simplify the sum of the two shift amounts, with OuterShift supplied
// as the context instruction.
//
// Returns a newly created 'shl' of the unmasked value by the original shift
// amount, or nullptr if the operand is not the masked pattern above, if the
// sum of the shift amounts does not simplify, or if the simplified sum is not
// matched as being u>= the bit width.
78+
static Instruction *
79+
dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift,
80+
const SimplifyQuery &SQ) {
81+
assert(OuterShift->getOpcode() == Instruction::BinaryOps::Shl &&
82+
"The input must be 'shl'!");
83+
84+
// Operand 0 is the value being shifted (the candidate masked value),
// operand 1 is the amount it is shifted left by.
Value *Masked = OuterShift->getOperand(0);
85+
Value *ShiftShAmt = OuterShift->getOperand(1);
86+
87+
Value *MaskShAmt;
88+
89+
// ((1 << MaskShAmt) - 1), where the '- 1' is matched as an 'add' of all-ones.
90+
auto MaskA = m_Add(m_Shl(m_One(), m_Value(MaskShAmt)), m_AllOnes());
91+
92+
// The mask may be either operand of the 'and'; the other operand is bound
// to X, which is the value we want to shift directly.
Value *X;
93+
if (!match(Masked, m_c_And(MaskA, m_Value(X))))
94+
return nullptr;
95+
96+
// Can we simplify (MaskShAmt+ShiftShAmt) ?
// Only simplification is attempted here; no new 'add' instruction is created.
97+
Value *SumOfShAmts =
98+
SimplifyAddInst(MaskShAmt, ShiftShAmt, /*IsNSW=*/false, /*IsNUW=*/false,
99+
SQ.getWithInstruction(OuterShift));
100+
if (!SumOfShAmts)
101+
return nullptr; // Did not simplify.
102+
// Is the total shift amount *not* smaller than the bit width?
103+
// FIXME: could also rely on ConstantRange.
104+
// The comparison constant is the scalar bit width of X's type, built as an
// APInt of that same width.
unsigned BitWidth = X->getType()->getScalarSizeInBits();
105+
if (!match(SumOfShAmts, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_UGE,
106+
APInt(BitWidth, BitWidth))))
107+
return nullptr;
108+
// All good, we can do this fold.
109+
110+
// No 'NUW'/'NSW'!
111+
// We no longer know that we won't shift-out non-0 bits.
// Hence the new shift is created from the opcode and operands only, and any
// wrap flags that were present on OuterShift are not copied over.
112+
return BinaryOperator::Create(OuterShift->getOpcode(), X, ShiftShAmt);
113+
}
114+
68115
Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) {
69116
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
70117
assert(Op0->getType() == Op1->getType());
@@ -629,6 +676,9 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) {
629676
if (Instruction *V = commonShiftTransforms(I))
630677
return V;
631678

679+
if (Instruction *V = dropRedundantMaskingOfLeftShiftInput(&I, SQ))
680+
return V;
681+
632682
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
633683
Type *Ty = I.getType();
634684
unsigned BitWidth = Ty->getScalarSizeInBits();

llvm/test/Transforms/InstCombine/redundant-left-shift-input-masking-variant-a.ll

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ define i32 @t0_basic(i32 %x, i32 %nbits) {
2525
; CHECK-NEXT: call void @use32(i32 [[T1]])
2626
; CHECK-NEXT: call void @use32(i32 [[T2]])
2727
; CHECK-NEXT: call void @use32(i32 [[T3]])
28-
; CHECK-NEXT: [[T4:%.*]] = shl i32 [[T2]], [[T3]]
28+
; CHECK-NEXT: [[T4:%.*]] = shl i32 [[X]], [[T3]]
2929
; CHECK-NEXT: ret i32 [[T4]]
3030
;
3131
%t0 = shl i32 1, %nbits
@@ -50,7 +50,7 @@ define i32 @t1_bigger_shift(i32 %x, i32 %nbits) {
5050
; CHECK-NEXT: call void @use32(i32 [[T1]])
5151
; CHECK-NEXT: call void @use32(i32 [[T2]])
5252
; CHECK-NEXT: call void @use32(i32 [[T3]])
53-
; CHECK-NEXT: [[T4:%.*]] = shl i32 [[T2]], [[T3]]
53+
; CHECK-NEXT: [[T4:%.*]] = shl i32 [[X]], [[T3]]
5454
; CHECK-NEXT: ret i32 [[T4]]
5555
;
5656
%t0 = shl i32 1, %nbits
@@ -77,7 +77,7 @@ define i32 @t2_bigger_mask(i32 %x, i32 %nbits) {
7777
; CHECK-NEXT: call void @use32(i32 [[T2]])
7878
; CHECK-NEXT: call void @use32(i32 [[T3]])
7979
; CHECK-NEXT: call void @use32(i32 [[T4]])
80-
; CHECK-NEXT: [[T5:%.*]] = shl i32 [[T3]], [[T4]]
80+
; CHECK-NEXT: [[T5:%.*]] = shl i32 [[X]], [[T4]]
8181
; CHECK-NEXT: ret i32 [[T5]]
8282
;
8383
%t0 = add i32 %nbits, 1
@@ -109,7 +109,7 @@ define <3 x i32> @t3_vec_splat(<3 x i32> %x, <3 x i32> %nbits) {
109109
; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T2]])
110110
; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T3]])
111111
; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T4]])
112-
; CHECK-NEXT: [[T5:%.*]] = shl <3 x i32> [[T3]], [[T4]]
112+
; CHECK-NEXT: [[T5:%.*]] = shl <3 x i32> [[X]], [[T4]]
113113
; CHECK-NEXT: ret <3 x i32> [[T5]]
114114
;
115115
%t0 = add <3 x i32> %nbits, <i32 0, i32 0, i32 0>
@@ -138,7 +138,7 @@ define <3 x i32> @t4_vec_nonsplat(<3 x i32> %x, <3 x i32> %nbits) {
138138
; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T2]])
139139
; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T3]])
140140
; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T4]])
141-
; CHECK-NEXT: [[T5:%.*]] = shl <3 x i32> [[T3]], [[T4]]
141+
; CHECK-NEXT: [[T5:%.*]] = shl <3 x i32> [[X]], [[T4]]
142142
; CHECK-NEXT: ret <3 x i32> [[T5]]
143143
;
144144
%t0 = add <3 x i32> %nbits, <i32 -1, i32 0, i32 1>
@@ -166,7 +166,7 @@ define <3 x i32> @t5_vec_undef(<3 x i32> %x, <3 x i32> %nbits) {
166166
; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T2]])
167167
; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T3]])
168168
; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T4]])
169-
; CHECK-NEXT: [[T5:%.*]] = shl <3 x i32> [[T3]], [[T4]]
169+
; CHECK-NEXT: [[T5:%.*]] = shl <3 x i32> [[X]], [[T4]]
170170
; CHECK-NEXT: ret <3 x i32> [[T5]]
171171
;
172172
%t0 = add <3 x i32> %nbits, <i32 0, i32 undef, i32 0>
@@ -198,7 +198,7 @@ define i32 @t6_commutativity0(i32 %nbits) {
198198
; CHECK-NEXT: call void @use32(i32 [[T1]])
199199
; CHECK-NEXT: call void @use32(i32 [[T2]])
200200
; CHECK-NEXT: call void @use32(i32 [[T3]])
201-
; CHECK-NEXT: [[T4:%.*]] = shl i32 [[T2]], [[T3]]
201+
; CHECK-NEXT: [[T4:%.*]] = shl i32 [[X]], [[T3]]
202202
; CHECK-NEXT: ret i32 [[T4]]
203203
;
204204
%x = call i32 @gen32()
@@ -260,7 +260,7 @@ define i32 @t8_commutativity2(i32 %nbits0, i32 %nbits1) {
260260
; CHECK-NEXT: call void @use32(i32 [[T3]])
261261
; CHECK-NEXT: call void @use32(i32 [[T4]])
262262
; CHECK-NEXT: call void @use32(i32 [[T5]])
263-
; CHECK-NEXT: [[T6:%.*]] = shl i32 [[T4]], [[T5]]
263+
; CHECK-NEXT: [[T6:%.*]] = shl i32 [[T1]], [[T5]]
264264
; CHECK-NEXT: ret i32 [[T6]]
265265
;
266266
%t0 = shl i32 1, %nbits0
@@ -291,7 +291,7 @@ define i32 @t9_nuw(i32 %x, i32 %nbits) {
291291
; CHECK-NEXT: call void @use32(i32 [[T1]])
292292
; CHECK-NEXT: call void @use32(i32 [[T2]])
293293
; CHECK-NEXT: call void @use32(i32 [[T3]])
294-
; CHECK-NEXT: [[T4:%.*]] = shl nuw i32 [[T2]], [[T3]]
294+
; CHECK-NEXT: [[T4:%.*]] = shl i32 [[X]], [[T3]]
295295
; CHECK-NEXT: ret i32 [[T4]]
296296
;
297297
%t0 = shl i32 1, %nbits
@@ -316,7 +316,7 @@ define i32 @t10_nsw(i32 %x, i32 %nbits) {
316316
; CHECK-NEXT: call void @use32(i32 [[T1]])
317317
; CHECK-NEXT: call void @use32(i32 [[T2]])
318318
; CHECK-NEXT: call void @use32(i32 [[T3]])
319-
; CHECK-NEXT: [[T4:%.*]] = shl nsw i32 [[T2]], [[T3]]
319+
; CHECK-NEXT: [[T4:%.*]] = shl i32 [[X]], [[T3]]
320320
; CHECK-NEXT: ret i32 [[T4]]
321321
;
322322
%t0 = shl i32 1, %nbits
@@ -341,7 +341,7 @@ define i32 @t11_nuw_nsw(i32 %x, i32 %nbits) {
341341
; CHECK-NEXT: call void @use32(i32 [[T1]])
342342
; CHECK-NEXT: call void @use32(i32 [[T2]])
343343
; CHECK-NEXT: call void @use32(i32 [[T3]])
344-
; CHECK-NEXT: [[T4:%.*]] = shl nuw nsw i32 [[T2]], [[T3]]
344+
; CHECK-NEXT: [[T4:%.*]] = shl i32 [[X]], [[T3]]
345345
; CHECK-NEXT: ret i32 [[T4]]
346346
;
347347
%t0 = shl i32 1, %nbits

0 commit comments

Comments
 (0)