@@ -117,7 +117,7 @@ class VectorCombine {
117117 bool foldShuffleOfShuffles (Instruction &I);
118118 bool foldShuffleToIdentity (Instruction &I);
119119 bool foldShuffleFromReductions (Instruction &I);
120- bool foldTruncFromReductions (Instruction &I);
120+ bool foldCastFromReductions (Instruction &I);
121121 bool foldSelectShuffle (Instruction &I, bool FromReduction = false );
122122
123123 void replaceValue (Value &Old, Value &New) {
@@ -2113,15 +2113,20 @@ bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
21132113
21142114// / Determine if its more efficient to fold:
21152115// / reduce(trunc(x)) -> trunc(reduce(x)).
2116- bool VectorCombine::foldTruncFromReductions (Instruction &I) {
2116+ // / reduce(sext(x)) -> sext(reduce(x)).
2117+ // / reduce(zext(x)) -> zext(reduce(x)).
2118+ bool VectorCombine::foldCastFromReductions (Instruction &I) {
21172119 auto *II = dyn_cast<IntrinsicInst>(&I);
21182120 if (!II)
21192121 return false ;
21202122
2123+ bool TruncOnly = false ;
21212124 Intrinsic::ID IID = II->getIntrinsicID ();
21222125 switch (IID) {
21232126 case Intrinsic::vector_reduce_add:
21242127 case Intrinsic::vector_reduce_mul:
2128+ TruncOnly = true ;
2129+ break ;
21252130 case Intrinsic::vector_reduce_and:
21262131 case Intrinsic::vector_reduce_or:
21272132 case Intrinsic::vector_reduce_xor:
@@ -2133,35 +2138,37 @@ bool VectorCombine::foldTruncFromReductions(Instruction &I) {
21332138 unsigned ReductionOpc = getArithmeticReductionInstruction (IID);
21342139 Value *ReductionSrc = I.getOperand (0 );
21352140
2136- Value *TruncSrc;
2137- if (!match (ReductionSrc, m_OneUse (m_Trunc (m_Value (TruncSrc)))))
2141+ Value *Src;
2142+ if (!match (ReductionSrc, m_OneUse (m_Trunc (m_Value (Src)))) &&
2143+ (TruncOnly || !match (ReductionSrc, m_OneUse (m_ZExtOrSExt (m_Value (Src))))))
21382144 return false ;
21392145
2140- auto *TruncSrcTy = cast<VectorType>(TruncSrc->getType ());
2146+ auto CastOpc =
2147+ (Instruction::CastOps)cast<Instruction>(ReductionSrc)->getOpcode ();
2148+
2149+ auto *SrcTy = cast<VectorType>(Src->getType ());
21412150 auto *ReductionSrcTy = cast<VectorType>(ReductionSrc->getType ());
21422151 Type *ResultTy = I.getType ();
21432152
21442153 TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
21452154 InstructionCost OldCost = TTI.getArithmeticReductionCost (
21462155 ReductionOpc, ReductionSrcTy, std::nullopt , CostKind);
2147- if (auto *Trunc = dyn_cast<CastInst>(ReductionSrc))
2148- OldCost +=
2149- TTI.getCastInstrCost (Instruction::Trunc, ReductionSrcTy, TruncSrcTy,
2150- TTI::CastContextHint::None, CostKind, Trunc);
2156+ OldCost += TTI.getCastInstrCost (CastOpc, ReductionSrcTy, SrcTy,
2157+ TTI::CastContextHint::None, CostKind,
2158+ cast<CastInst>(ReductionSrc));
21512159 InstructionCost NewCost =
2152- TTI.getArithmeticReductionCost (ReductionOpc, TruncSrcTy , std::nullopt ,
2160+ TTI.getArithmeticReductionCost (ReductionOpc, SrcTy , std::nullopt ,
21532161 CostKind) +
2154- TTI.getCastInstrCost (Instruction::Trunc, ResultTy,
2155- ReductionSrcTy->getScalarType (),
2162+ TTI.getCastInstrCost (CastOpc, ResultTy, ReductionSrcTy->getScalarType (),
21562163 TTI::CastContextHint::None, CostKind);
21572164
21582165 if (OldCost <= NewCost || !NewCost.isValid ())
21592166 return false ;
21602167
2161- Value *NewReduction = Builder.CreateIntrinsic (
2162- TruncSrcTy-> getScalarType (), II->getIntrinsicID (), {TruncSrc });
2163- Value *NewTruncation = Builder.CreateTrunc ( NewReduction, ResultTy);
2164- replaceValue (I, *NewTruncation );
2168+ Value *NewReduction = Builder.CreateIntrinsic (SrcTy-> getScalarType (),
2169+ II->getIntrinsicID (), {Src });
2170+ Value *NewCast = Builder.CreateCast (CastOpc, NewReduction, ResultTy);
2171+ replaceValue (I, *NewCast );
21652172 return true ;
21662173}
21672174
@@ -2559,7 +2566,7 @@ bool VectorCombine::run() {
25592566 switch (Opcode) {
25602567 case Instruction::Call:
25612568 MadeChange |= foldShuffleFromReductions (I);
2562- MadeChange |= foldTruncFromReductions (I);
2569+ MadeChange |= foldCastFromReductions (I);
25632570 break ;
25642571 case Instruction::ICmp:
25652572 case Instruction::FCmp:
0 commit comments