|
64 | 64 | #include "llvm/Support/KnownBits.h" |
65 | 65 | #include "llvm/Support/KnownFPClass.h" |
66 | 66 | #include "llvm/Support/MathExtras.h" |
| 67 | +#include "llvm/Support/TypeSize.h" |
67 | 68 | #include "llvm/Support/raw_ostream.h" |
68 | 69 | #include "llvm/Transforms/InstCombine/InstCombiner.h" |
69 | 70 | #include "llvm/Transforms/Utils/AssumeBundleBuilder.h" |
@@ -3769,29 +3770,24 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { |
3769 | 3770 | // %4 = tail call i32 @llvm.vector.reduce.add.v4i32(%3) |
3770 | 3771 | // => |
3771 | 3772 | // %2 = shl i32 %0, 2 |
3772 | | - Value *InputValue; |
3773 | | - ArrayRef<int> Mask; |
3774 | | - ConstantInt *InsertionIdx; |
3775 | 3773 | assert(Arg->getType()->isVectorTy() && |
3776 | 3774 | "The vector.reduce.add intrinsic's argument must be a vector!"); |
3777 | 3775 |
|
3778 | | - if (match(Arg, m_Shuffle(m_InsertElt(m_Poison(), m_Value(InputValue), |
3779 | | - m_ConstantInt(InsertionIdx)), |
3780 | | - m_Poison(), m_Mask(Mask)))) { |
| 3776 | + if (Value *Splat = getSplatValue(Arg)) { |
3781 | 3777 | // It is only a multiplication if we add the same element over and over. |
3782 | | - bool AllElementsAreTheSameInMask = |
3783 | | - std::all_of(Mask.begin(), Mask.end(), |
3784 | | - [&Mask](int MaskElt) { return MaskElt == Mask[0]; }); |
3785 | | - unsigned ReducedVectorLength = Mask.size(); |
3786 | | - |
3787 | | - if (AllElementsAreTheSameInMask && |
3788 | | - InsertionIdx->getSExtValue() == Mask[0] && |
3789 | | - isPowerOf2_32(ReducedVectorLength)) { |
3790 | | - unsigned Pow2 = Log2_32(ReducedVectorLength); |
3791 | | - Value *Res = Builder.CreateShl( |
3792 | | - InputValue, Constant::getIntegerValue(InputValue->getType(), |
3793 | | - APInt(32, Pow2))); |
3794 | | - return replaceInstUsesWith(CI, Res); |
| 3778 | + ElementCount ReducedVectorElementCount = |
| 3779 | + static_cast<VectorType *>(Arg->getType())->getElementCount(); |
| 3780 | + if (ReducedVectorElementCount.isFixed()) { |
| 3781 | + unsigned VectorSize = ReducedVectorElementCount.getFixedValue(); |
| 3782 | + if (isPowerOf2_32(VectorSize)) { |
| 3783 | + unsigned Pow2 = Log2_32(VectorSize); |
| 3784 | + Value *Res = Builder.CreateShl( |
| 3785 | + Splat, |
| 3786 | + Constant::getIntegerValue( |
| 3787 | + Splat->getType(), |
| 3788 | + APInt(Splat->getType()->getIntegerBitWidth(), Pow2))); |
| 3789 | + return replaceInstUsesWith(CI, Res); |
| 3790 | + } |
3795 | 3791 | } |
3796 | 3792 | } |
3797 | 3793 | } |
|
0 commit comments