@@ -311,18 +311,27 @@ VPPartialReductionRecipe::computeCost(ElementCount VF,
311311 std::optional<unsigned > Opcode;
312312 VPValue *Op = getVecOp ();
313313 uint64_t MulConst;
314+
315+ InstructionCost CondCost = 0 ;
316+ if (isConditional ()) {
317+ CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
318+ auto *VecTy = Ctx.Types .inferScalarType (Op);
319+ auto *CondTy = Ctx.Types .inferScalarType (getCondOp ());
320+ CondCost = Ctx.TTI .getCmpSelInstrCost (Instruction::Select, VecTy, CondTy,
321+ Pred, Ctx.CostKind );
322+ }
323+
314324 // If the partial reduction is predicated, the select on the condition is costed separately (see CondCost).
315325 // If it isn't predicated and the mul isn't operating on a constant, then it
316326 // should have been turned into a VPExpressionRecipe.
317327 // FIXME: Replace the entire function with this once all partial reduction
318328 // variants are bundled into VPExpressionRecipe.
319- if (!match (Op, m_Select (m_VPValue (), m_VPValue (Op), m_VPValue ())) &&
320- !match (Op, m_Mul (m_VPValue (), m_ConstantInt (MulConst)))) {
329+ if (!match (Op, m_Mul (m_VPValue (), m_ConstantInt (MulConst)))) {
321330 auto *PhiType = Ctx.Types .inferScalarType (getChainOp ());
322331 auto *InputType = Ctx.Types .inferScalarType (getVecOp ());
323- return Ctx.TTI .getPartialReductionCost (getOpcode (), InputType, InputType,
324- PhiType, VF, TTI::PR_None ,
325- TTI::PR_None, {}, Ctx.CostKind );
332+ return CondCost + Ctx.TTI .getPartialReductionCost (
333+ getOpcode (), InputType, InputType, PhiType, VF,
334+ TTI::PR_None, TTI::PR_None, {}, Ctx.CostKind );
326335 }
327336
328337 VPRecipeBase *OpR = Op->getDefiningRecipe ();
@@ -381,12 +390,13 @@ VPPartialReductionRecipe::computeCost(ElementCount VF,
381390 } else if (auto Widen = dyn_cast<VPWidenRecipe>(OpR)) {
382391 HandleWiden (Widen);
383392 } else if (auto Reduction = dyn_cast<VPPartialReductionRecipe>(OpR)) {
384- return Reduction->computeCost (VF, Ctx);
393+ return CondCost + Reduction->computeCost (VF, Ctx);
385394 }
386395 auto *PhiType = Ctx.Types .inferScalarType (getOperand (1 ));
387- return Ctx.TTI .getPartialReductionCost (getOpcode (), InputTypeA, InputTypeB,
388- PhiType, VF, ExtAType, ExtBType,
389- Opcode, Ctx.CostKind );
396+ return CondCost + Ctx.TTI .getPartialReductionCost (
397+ getOpcode (), InputTypeA, InputTypeB, PhiType, VF,
398+ ExtAType, ExtBType, Opcode, Ctx.CostKind );
399+ ;
390400}
391401
392402void VPPartialReductionRecipe::execute (VPTransformState &State) {
@@ -395,12 +405,18 @@ void VPPartialReductionRecipe::execute(VPTransformState &State) {
395405 assert (getOpcode () == Instruction::Add &&
396406 " Unhandled partial reduction opcode" );
397407
398- Value *BinOpVal = State.get (getOperand ( 1 ));
399- Value *PhiVal = State.get (getOperand ( 0 ));
408+ Value *BinOpVal = State.get (getVecOp ( ));
409+ Value *PhiVal = State.get (getChainOp ( ));
400410 assert (PhiVal && BinOpVal && " Phi and Mul must be set" );
401411
402412 Type *RetTy = PhiVal->getType ();
403413
414+ if (isConditional ()) {
415+ Value *Cond = State.get (getCondOp ());
416+ Value *Zero = ConstantInt::get (BinOpVal->getType (), 0 );
417+ BinOpVal = Builder.CreateSelect (Cond, BinOpVal, Zero);
418+ }
419+
404420 CallInst *V =
405421 Builder.CreateIntrinsic (RetTy, Intrinsic::vector_partial_reduce_add,
406422 {PhiVal, BinOpVal}, nullptr , " partial.reduce" );
0 commit comments