@@ -307,94 +307,11 @@ bool VPRecipeBase::isScalarCast() const {
307307InstructionCost
308308VPPartialReductionRecipe::computeCost (ElementCount VF,
309309 VPCostContext &Ctx) const {
310- std::optional<unsigned > Opcode;
311- VPValue *Op = getVecOp ();
312- uint64_t MulConst;
313-
314- InstructionCost CondCost = 0 ;
315- if (isConditional ()) {
316- CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
317- auto *VecTy = Ctx.Types .inferScalarType (Op);
318- auto *CondTy = Ctx.Types .inferScalarType (getCondOp ());
319- CondCost = Ctx.TTI .getCmpSelInstrCost (Instruction::Select, VecTy, CondTy,
320- Pred, Ctx.CostKind );
321- }
322-
323- // If the partial reduction is predicated, a select will be operand 1.
324- // If it isn't predicated and the mul isn't operating on a constant, then it
325- // should have been turned into a VPExpressionRecipe.
326- // FIXME: Replace the entire function with this once all partial reduction
327- // variants are bundled into VPExpressionRecipe.
328- if (!match (Op, m_Mul (m_VPValue (), m_ConstantInt (MulConst)))) {
329- auto *PhiType = Ctx.Types .inferScalarType (getChainOp ());
330- auto *InputType = Ctx.Types .inferScalarType (getVecOp ());
331- return CondCost + Ctx.TTI .getPartialReductionCost (
332- getOpcode (), InputType, InputType, PhiType, VF,
333- TTI::PR_None, TTI::PR_None, {}, Ctx.CostKind );
334- }
335-
336- VPRecipeBase *OpR = Op->getDefiningRecipe ();
337- Type *InputTypeA = nullptr , *InputTypeB = nullptr ;
338- TTI::PartialReductionExtendKind ExtAType = TTI::PR_None,
339- ExtBType = TTI::PR_None;
340-
341- auto GetExtendKind = [](VPRecipeBase *R) {
342- if (!R)
343- return TTI::PR_None;
344- auto *WidenCastR = dyn_cast<VPWidenCastRecipe>(R);
345- if (!WidenCastR)
346- return TTI::PR_None;
347- if (WidenCastR->getOpcode () == Instruction::CastOps::ZExt)
348- return TTI::PR_ZeroExtend;
349- if (WidenCastR->getOpcode () == Instruction::CastOps::SExt)
350- return TTI::PR_SignExtend;
351- return TTI::PR_None;
352- };
353-
354- // Pick out opcode, type/ext information and use sub side effects from a widen
355- // recipe.
356- auto HandleWiden = [&](VPWidenRecipe *Widen) {
357- if (match (Widen, m_Sub (m_ZeroInt (), m_VPValue (Op)))) {
358- Widen = dyn_cast<VPWidenRecipe>(Op);
359- }
360- Opcode = Widen->getOpcode ();
361- VPRecipeBase *ExtAR = Widen->getOperand (0 )->getDefiningRecipe ();
362- VPRecipeBase *ExtBR = Widen->getOperand (1 )->getDefiningRecipe ();
363- InputTypeA = Ctx.Types .inferScalarType (ExtAR ? ExtAR->getOperand (0 )
364- : Widen->getOperand (0 ));
365- InputTypeB = Ctx.Types .inferScalarType (ExtBR ? ExtBR->getOperand (0 )
366- : Widen->getOperand (1 ));
367- ExtAType = GetExtendKind (ExtAR);
368- ExtBType = GetExtendKind (ExtBR);
369-
370- using namespace VPlanPatternMatch ;
371- const APInt *C;
372- if (!ExtBR && match (Widen->getOperand (1 ), m_APInt (C)) &&
373- canConstantBeExtended (C, InputTypeA, ExtAType)) {
374- InputTypeB = InputTypeA;
375- ExtBType = ExtAType;
376- }
377- };
378-
379- if (isa<VPWidenCastRecipe>(OpR)) {
380- InputTypeA = Ctx.Types .inferScalarType (OpR->getOperand (0 ));
381- ExtAType = GetExtendKind (OpR);
382- } else if (isa<VPReductionPHIRecipe>(OpR)) {
383- if (auto RedPhiOp1R = dyn_cast_or_null<VPWidenCastRecipe>(getOperand (1 ))) {
384- InputTypeA = Ctx.Types .inferScalarType (RedPhiOp1R->getOperand (0 ));
385- ExtAType = GetExtendKind (RedPhiOp1R);
386- } else if (auto Widen = dyn_cast_or_null<VPWidenRecipe>(getOperand (1 )))
387- HandleWiden (Widen);
388- } else if (auto Widen = dyn_cast<VPWidenRecipe>(OpR)) {
389- HandleWiden (Widen);
390- } else if (auto Reduction = dyn_cast<VPPartialReductionRecipe>(OpR)) {
391- return CondCost + Reduction->computeCost (VF, Ctx);
392- }
393- auto *PhiType = Ctx.Types .inferScalarType (getOperand (1 ));
394- return CondCost + Ctx.TTI .getPartialReductionCost (
395- getOpcode (), InputTypeA, InputTypeB, PhiType, VF,
396- ExtAType, ExtBType, Opcode, Ctx.CostKind );
397- ;
310+ auto *PhiType = Ctx.Types .inferScalarType (getChainOp ());
311+ auto *InputType = Ctx.Types .inferScalarType (getVecOp ());
312+ return Ctx.TTI .getPartialReductionCost (getOpcode (), InputType, InputType,
313+ PhiType, VF, TTI::PR_None,
314+ TTI::PR_None, {}, Ctx.CostKind );
398315}
399316
400317void VPPartialReductionRecipe::execute (VPTransformState &State) {
0 commit comments