-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[InstCombine] Added optimization for shift add #163502
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
@llvm/pr-subscribers-llvm-transforms Author: None (manik-muk) Changes: Addresses #163115 Full diff: https://github.com/llvm/llvm-project/pull/163502.diff 2 Files Affected:
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index d457e0c7dd1c4..fc2a0018e725c 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -1803,6 +1803,30 @@ Instruction *InstCombinerImpl::visitAShr(BinaryOperator &I) {
cast<OverflowingBinaryOperator>(Op0)->hasNoUnsignedWrap());
return NewAdd;
}
+
+ // Fold ((X << A) + C) >> B --> (X << (A - B)) + (C >> B)
+ // when the shift is exact and the add is nsw.
+ // This transforms patterns like: ((x << 4) + 16) ashr exact 1 --> (x <<
+ // 3) + 8
+ const APInt *ShlAmt, *AddC;
+ if (I.isExact() &&
+ match(Op0, m_c_NSWAdd(m_NSWShl(m_Value(X), m_APInt(ShlAmt)),
+ m_APInt(AddC))) &&
+ ShlAmt->uge(ShAmt)) {
+ // Check if C is divisible by (1 << ShAmt)
+ if (AddC->isShiftedMask() || AddC->countTrailingZeros() >= ShAmt ||
+ AddC->ashr(ShAmt).shl(ShAmt) == *AddC) {
+ // X << (A - B)
+ Constant *NewShlAmt = ConstantInt::get(Ty, *ShlAmt - ShAmt);
+ Value *NewShl = Builder.CreateShl(X, NewShlAmt);
+
+ // C >> B
+ Constant *NewAddC = ConstantInt::get(Ty, AddC->ashr(ShAmt));
+
+ // (X << (A - B)) + (C >> B)
+ return BinaryOperator::CreateAdd(NewShl, NewAddC);
+ }
+ }
}
const SimplifyQuery Q = SQ.getWithInstruction(&I);
diff --git a/llvm/test/Transforms/InstCombine/shift-add.ll b/llvm/test/Transforms/InstCombine/shift-add.ll
index 81cbc2ac23b5f..1d1f219904f74 100644
--- a/llvm/test/Transforms/InstCombine/shift-add.ll
+++ b/llvm/test/Transforms/InstCombine/shift-add.ll
@@ -804,3 +804,147 @@ define <2 x i8> @lshr_fold_or_disjoint_cnt_out_of_bounds(<2 x i8> %x) {
%r = lshr <2 x i8> <i8 2, i8 3>, %a
ret <2 x i8> %r
}
+
+define i32 @ashr_exact_add_shl_fold(i32 %arg0) {
+; CHECK-LABEL: @ashr_exact_add_shl_fold(
+; CHECK-NEXT: [[V0:%.*]] = shl i32 [[ARG0:%.*]], 3
+; CHECK-NEXT: [[V2:%.*]] = add i32 [[V0]], 8
+; CHECK-NEXT: ret i32 [[V2]]
+;
+ %v0 = shl nsw i32 %arg0, 4
+ %v1 = add nsw i32 %v0, 16
+ %v2 = ashr exact i32 %v1, 1
+ ret i32 %v2
+}
+
+; Test with larger shift amounts
+define i32 @ashr_exact_add_shl_fold_larger_shift(i32 %arg0) {
+; CHECK-LABEL: @ashr_exact_add_shl_fold_larger_shift(
+; CHECK-NEXT: [[V0:%.*]] = shl i32 [[ARG0:%.*]], 1
+; CHECK-NEXT: [[V2:%.*]] = add i32 [[V0]], 2
+; CHECK-NEXT: ret i32 [[V2]]
+;
+ %v0 = shl nsw i32 %arg0, 4
+ %v1 = add nsw i32 %v0, 16
+ %v2 = ashr exact i32 %v1, 3
+ ret i32 %v2
+}
+
+; Test with negative constant
+define i32 @ashr_exact_add_shl_fold_negative_const(i32 %arg0) {
+; CHECK-LABEL: @ashr_exact_add_shl_fold_negative_const(
+; CHECK-NEXT: [[V0:%.*]] = shl i32 [[ARG0:%.*]], 2
+; CHECK-NEXT: [[V2:%.*]] = add i32 [[V0]], -4
+; CHECK-NEXT: ret i32 [[V2]]
+;
+ %v0 = shl nsw i32 %arg0, 4
+ %v1 = add nsw i32 %v0, -16
+ %v2 = ashr exact i32 %v1, 2
+ ret i32 %v2
+}
+
+; Test where shift amount equals shl amount (result is just the constant)
+define i32 @ashr_exact_add_shl_fold_equal_shifts(i32 %arg0) {
+; CHECK-LABEL: @ashr_exact_add_shl_fold_equal_shifts(
+; CHECK-NEXT: [[V2:%.*]] = add i32 [[ARG0:%.*]], 1
+; CHECK-NEXT: ret i32 [[V2]]
+;
+ %v0 = shl nsw i32 %arg0, 4
+ %v1 = add nsw i32 %v0, 16
+ %v2 = ashr exact i32 %v1, 4
+ ret i32 %v2
+}
+
+; Negative test: not exact - should not transform
+define i32 @ashr_add_shl_no_exact(i32 %arg0) {
+; CHECK-LABEL: @ashr_add_shl_no_exact(
+; CHECK-NEXT: [[TMP1:%.*]] = shl i32 [[ARG0:%.*]], 3
+; CHECK-NEXT: [[V2:%.*]] = add i32 [[TMP1]], 8
+; CHECK-NEXT: ret i32 [[V2]]
+;
+ %v0 = shl nsw i32 %arg0, 4
+ %v1 = add nsw i32 %v0, 16
+ %v2 = ashr i32 %v1, 1
+ ret i32 %v2
+}
+
+; Negative test: add is not nsw - should not transform
+define i32 @ashr_exact_add_shl_no_nsw_add(i32 %arg0) {
+; CHECK-LABEL: @ashr_exact_add_shl_no_nsw_add(
+; CHECK-NEXT: [[V0:%.*]] = shl nsw i32 [[ARG0:%.*]], 4
+; CHECK-NEXT: [[V1:%.*]] = add i32 [[V0]], 16
+; CHECK-NEXT: [[V2:%.*]] = ashr exact i32 [[V1]], 1
+; CHECK-NEXT: ret i32 [[V2]]
+;
+ %v0 = shl nsw i32 %arg0, 4
+ %v1 = add i32 %v0, 16
+ %v2 = ashr exact i32 %v1, 1
+ ret i32 %v2
+}
+
+; Negative test: shl is not nsw - should not transform
+define i32 @ashr_exact_add_shl_no_nsw_shl(i32 %arg0) {
+; CHECK-LABEL: @ashr_exact_add_shl_no_nsw_shl(
+; CHECK-NEXT: [[V0:%.*]] = shl i32 [[ARG0:%.*]], 4
+; CHECK-NEXT: [[V1:%.*]] = add nsw i32 [[V0]], 16
+; CHECK-NEXT: [[V2:%.*]] = ashr exact i32 [[V1]], 1
+; CHECK-NEXT: ret i32 [[V2]]
+;
+ %v0 = shl i32 %arg0, 4
+ %v1 = add nsw i32 %v0, 16
+ %v2 = ashr exact i32 %v1, 1
+ ret i32 %v2
+}
+
+; Negative test: constant not divisible by shift amount
+define i32 @ashr_exact_add_shl_not_divisible(i32 %arg0) {
+; CHECK-LABEL: @ashr_exact_add_shl_not_divisible(
+; CHECK-NEXT: [[V0:%.*]] = shl nsw i32 [[ARG0:%.*]], 4
+; CHECK-NEXT: [[V1:%.*]] = add nsw i32 [[V0]], 17
+; CHECK-NEXT: ret i32 [[V1]]
+;
+ %v0 = shl nsw i32 %arg0, 4
+ %v1 = add nsw i32 %v0, 17
+ %v2 = ashr exact i32 %v1, 1
+ ret i32 %v2
+}
+
+; Negative test: shift amount greater than shl amount
+define i32 @ashr_exact_add_shl_shift_too_large(i32 %arg0) {
+; CHECK-LABEL: @ashr_exact_add_shl_shift_too_large(
+; CHECK-NEXT: [[V0:%.*]] = shl nsw i32 [[ARG0:%.*]], 2
+; CHECK-NEXT: [[V1:%.*]] = add nsw i32 [[V0]], 16
+; CHECK-NEXT: [[V2:%.*]] = ashr exact i32 [[V1]], 4
+; CHECK-NEXT: ret i32 [[V2]]
+;
+ %v0 = shl nsw i32 %arg0, 2
+ %v1 = add nsw i32 %v0, 16
+ %v2 = ashr exact i32 %v1, 4
+ ret i32 %v2
+}
+
+; Vector test
+define <2 x i32> @ashr_exact_add_shl_fold_vector(<2 x i32> %arg0) {
+; CHECK-LABEL: @ashr_exact_add_shl_fold_vector(
+; CHECK-NEXT: [[TMP1:%.*]] = shl <2 x i32> [[ARG0:%.*]], splat (i32 3)
+; CHECK-NEXT: [[V2:%.*]] = add <2 x i32> [[TMP1]], splat (i32 8)
+; CHECK-NEXT: ret <2 x i32> [[V2]]
+;
+ %v0 = shl nsw <2 x i32> %arg0, <i32 4, i32 4>
+ %v1 = add nsw <2 x i32> %v0, <i32 16, i32 16>
+ %v2 = ashr exact <2 x i32> %v1, <i32 1, i32 1>
+ ret <2 x i32> %v2
+}
+
+; Test commutative add (constant on left)
+define i32 @ashr_exact_add_shl_fold_commute(i32 %arg0) {
+; CHECK-LABEL: @ashr_exact_add_shl_fold_commute(
+; CHECK-NEXT: [[V0:%.*]] = shl i32 [[ARG0:%.*]], 3
+; CHECK-NEXT: [[V2:%.*]] = add i32 [[V0]], 8
+; CHECK-NEXT: ret i32 [[V2]]
+;
+ %v0 = shl nsw i32 %arg0, 4
+ %v1 = add nsw i32 16, %v0
+ %v2 = ashr exact i32 %v1, 1
+ ret i32 %v2
+}
|
|
Please see https://llvm.org/docs/InstCombineContributorGuide.html. In particular this is missing generalized alive2 proofs. |
|
Note that this also applies to lshr, just with nuw instead of nsw: https://alive2.llvm.org/ce/z/Hpt_LH Ideally, we would cover this by extensions to canEvaluateShifted() instead of dedicated patterns. For example, can we handle the lshr variant by adding |
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
| // Fold ((X << A) + C) >>u B --> (X << (A - B)) + (C >>u B) | ||
| // when the shift is exact and the add has nuw. | ||
| const APInt *ShAmtAPInt, *ShlAmt, *AddC; | ||
| if (match(Op1, m_APInt(ShAmtAPInt)) && I.isExact() && |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we still need this and the ashr variant? Isn't it covered by the changes in canEvaluateShifted and getShiftedValue?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think so! canEvaluateShifted always excludes ashr because ashr is an arithmetic shift and canEvaluateShifted is for logical shifts.
lshr does use the canEvaluateShifted function, but it has some stricter requirements, hence the extra logic.
Addresses #163115
alive2 proof: https://alive2.llvm.org/ce/z/sumyA7