From 01055b1a6f565d8d10ba7efa2206b8f0dac84cca Mon Sep 17 00:00:00 2001 From: Yingwei Zheng Date: Mon, 6 Jan 2025 00:21:58 +0800 Subject: [PATCH 1/4] [InstCombine] Add pre-commit tests. NFC. --- .../Transforms/InstCombine/add-shl-sdiv-to-srem.ll | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/llvm/test/Transforms/InstCombine/add-shl-sdiv-to-srem.ll b/llvm/test/Transforms/InstCombine/add-shl-sdiv-to-srem.ll index 84462f9a7f592..60bfe3f8665b7 100644 --- a/llvm/test/Transforms/InstCombine/add-shl-sdiv-to-srem.ll +++ b/llvm/test/Transforms/InstCombine/add-shl-sdiv-to-srem.ll @@ -12,6 +12,19 @@ define i8 @add-shl-sdiv-scalar0(i8 %x) { ret i8 %rz } +define i8 @add-shl-sdiv-scalar0_commuted(i8 %x) { +; CHECK-LABEL: @add-shl-sdiv-scalar0_commuted( +; CHECK-NEXT: [[SD:%.*]] = sdiv i8 [[X:%.*]], -4 +; CHECK-NEXT: [[SL:%.*]] = shl i8 [[SD]], 2 +; CHECK-NEXT: [[RZ:%.*]] = add i8 [[X]], [[SL]] +; CHECK-NEXT: ret i8 [[RZ]] +; + %sd = sdiv i8 %x, -4 + %sl = shl i8 %sd, 2 + %rz = add i8 %x, %sl + ret i8 %rz +} + define i8 @add-shl-sdiv-scalar1(i8 %x) { ; CHECK-LABEL: @add-shl-sdiv-scalar1( ; CHECK-NEXT: [[RZ:%.*]] = srem i8 [[X:%.*]], 64 From 79c303869c61e8d4479aaaa29b8b677c45a9318b Mon Sep 17 00:00:00 2001 From: Yingwei Zheng Date: Mon, 6 Jan 2025 00:50:45 +0800 Subject: [PATCH 2/4] [InstCombine] Handle commuted pattern for `((X s/ C1) << C2) + X` --- llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp | 7 ++++--- llvm/test/Transforms/InstCombine/add-shl-sdiv-to-srem.ll | 4 +--- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp index 7a184a19d7c54..74d17067de16e 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -1625,12 +1625,13 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) { // ((X s/ C1) << C2) + X => X s% -C1 where -C1 is 1 << C2 const APInt *C1, *C2; - if (match(LHS, m_Shl(m_SDiv(m_Specific(RHS), m_APInt(C1)), m_APInt(C2)))) { + if (match(&I, m_c_Add(m_Shl(m_SDiv(m_Value(A), m_APInt(C1)), m_APInt(C2)), + m_Deferred(A)))) { APInt one(C2->getBitWidth(), 1); APInt minusC1 = -(*C1); if (minusC1 == (one << *C2)) { - Constant *NewRHS = ConstantInt::get(RHS->getType(), minusC1); - return BinaryOperator::CreateSRem(RHS, NewRHS); + Constant *NewRHS = ConstantInt::get(A->getType(), minusC1); + return BinaryOperator::CreateSRem(A, NewRHS); } } diff --git a/llvm/test/Transforms/InstCombine/add-shl-sdiv-to-srem.ll b/llvm/test/Transforms/InstCombine/add-shl-sdiv-to-srem.ll index 60bfe3f8665b7..d4edf12eba6ac 100644 --- a/llvm/test/Transforms/InstCombine/add-shl-sdiv-to-srem.ll +++ b/llvm/test/Transforms/InstCombine/add-shl-sdiv-to-srem.ll @@ -14,9 +14,7 @@ define i8 @add-shl-sdiv-scalar0(i8 %x) { define i8 @add-shl-sdiv-scalar0_commuted(i8 %x) { ; CHECK-LABEL: @add-shl-sdiv-scalar0_commuted( -; CHECK-NEXT: [[SD:%.*]] = sdiv i8 [[X:%.*]], -4 -; CHECK-NEXT: [[SL:%.*]] = shl i8 [[SD]], 2 -; CHECK-NEXT: [[RZ:%.*]] = add i8 [[X]], [[SL]] +; CHECK-NEXT: [[RZ:%.*]] = srem i8 [[X:%.*]], 4 ; CHECK-NEXT: ret i8 [[RZ]] ; %sd = sdiv i8 %x, -4 From 31a6f4a6dd81a6fe5ef956069d24826c9ccee129 Mon Sep 17 00:00:00 2001 From: Yingwei Zheng Date: Mon, 6 Jan 2025 18:22:31 +0800 Subject: [PATCH 3/4] [InstCombine] Move the logic into `foldAddLikeCommutative` --- .../InstCombine/InstCombineAddSub.cpp | 25 ++++++++++--------- 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp index 74d17067de16e..dee07b260dcd3 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -1326,6 +1326,18 @@ Instruction *InstCombinerImpl::foldAddLikeCommutative(Value *LHS, Value *RHS, R->setHasNoUnsignedWrap(NUWOut); return R; } + + // ((X s/ C1) << C2) + X => X s% -C1 where -C1 is 1 << C2 + const APInt *C1, *C2; + if (match(LHS, m_Shl(m_SDiv(m_Specific(RHS), m_APInt(C1)), m_APInt(C2)))) { + APInt one(C2->getBitWidth(), 1); + APInt minusC1 = -(*C1); + if (minusC1 == (one << *C2)) { + Constant *NewRHS = ConstantInt::get(RHS->getType(), minusC1); + return BinaryOperator::CreateSRem(RHS, NewRHS); + } + } + return nullptr; } @@ -1623,18 +1635,7 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) { // X % C0 + (( X / C0 ) % C1) * C0 => X % (C0 * C1) if (Value *V = SimplifyAddWithRemainder(I)) return replaceInstUsesWith(I, V); - // ((X s/ C1) << C2) + X => X s% -C1 where -C1 is 1 << C2 - const APInt *C1, *C2; - if (match(&I, m_c_Add(m_Shl(m_SDiv(m_Value(A), m_APInt(C1)), m_APInt(C2)), - m_Deferred(A)))) { - APInt one(C2->getBitWidth(), 1); - APInt minusC1 = -(*C1); - if (minusC1 == (one << *C2)) { - Constant *NewRHS = ConstantInt::get(A->getType(), minusC1); - return BinaryOperator::CreateSRem(A, NewRHS); - } - } - + const APInt *C1; // (A & 2^C1) + A => A & (2^C1 - 1) iff bit C1 in A is a sign bit if (match(&I, m_c_Add(m_And(m_Value(A), m_APInt(C1)), m_Deferred(A))) && C1->isPowerOf2() && (ComputeNumSignBits(A) > C1->countl_zero())) { From 4a615c32a3d985f0a663122b68035db8da84a813 Mon Sep 17 00:00:00 2001 From: Yingwei Zheng Date: Mon, 6 Jan 2025 19:02:32 +0800 Subject: [PATCH 4/4] [InstCombine] Fix coding style. NFC. --- llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp index dee07b260dcd3..a2769c96b2ef4 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -1330,10 +1330,10 @@ Instruction *InstCombinerImpl::foldAddLikeCommutative(Value *LHS, Value *RHS, // ((X s/ C1) << C2) + X => X s% -C1 where -C1 is 1 << C2 const APInt *C1, *C2; if (match(LHS, m_Shl(m_SDiv(m_Specific(RHS), m_APInt(C1)), m_APInt(C2)))) { - APInt one(C2->getBitWidth(), 1); - APInt minusC1 = -(*C1); - if (minusC1 == (one << *C2)) { - Constant *NewRHS = ConstantInt::get(RHS->getType(), minusC1); + APInt One(C2->getBitWidth(), 1); + APInt MinusC1 = -(*C1); + if (MinusC1 == (One << *C2)) { + Constant *NewRHS = ConstantInt::get(RHS->getType(), MinusC1); return BinaryOperator::CreateSRem(RHS, NewRHS); } }