@@ -3376,8 +3376,8 @@ InstructionCost AArch64TTIImpl::getScalarizationOverhead(
33763376InstructionCost AArch64TTIImpl::getArithmeticInstrCost (
33773377 unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
33783378 TTI::OperandValueInfo Op1Info, TTI::OperandValueInfo Op2Info,
3379- ArrayRef<const Value *> Args,
3380- const Instruction *CxtI ) {
3379+ ArrayRef<const Value *> Args, const Instruction *CxtI,
3380+ ArrayRef<Value *> Scalars ) {
33813381
33823382 // The code-generator is currently not able to handle scalable vectors
33833383 // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
@@ -3442,8 +3442,8 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
34423442 if (!VT.isVector () && VT.getSizeInBits () > 64 )
34433443 return getCallInstrCost (/* Function*/ nullptr , Ty, {Ty, Ty}, CostKind);
34443444
3445- InstructionCost Cost = BaseT::getArithmeticInstrCost (
3446- Opcode, Ty, CostKind, Op1Info, Op2Info);
3445+ InstructionCost Cost =
3446+ BaseT::getArithmeticInstrCost ( Opcode, Ty, CostKind, Op1Info, Op2Info);
34473447 if (Ty->isVectorTy ()) {
34483448 if (TLI->isOperationLegalOrCustom (ISD, LT.second ) && ST->hasSVE ()) {
34493449 // SDIV/UDIV operations are lowered using SVE, then we can have less
@@ -3472,29 +3472,41 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
34723472 Cost *= 4 ;
34733473 return Cost;
34743474 } else {
3475- // If one of the operands is a uniform constant then the cost for each
3476- // element is Cost for insertion, extraction and division.
3477- // Insertion cost = 2, Extraction Cost = 2, Division = cost for the
3478- // operation with scalar type
3479- if ((Op1Info.isConstant () && Op1Info.isUniform ()) ||
3480- (Op2Info.isConstant () && Op2Info.isUniform ())) {
3481- if (auto *VTy = dyn_cast<FixedVectorType>(Ty)) {
3475+ if (auto *VTy = dyn_cast<FixedVectorType>(Ty)) {
3476+ if ((Op1Info.isConstant () && Op1Info.isUniform ()) ||
3477+ (Op2Info.isConstant () && Op2Info.isUniform ())) {
34823478 InstructionCost DivCost = BaseT::getArithmeticInstrCost (
34833479 Opcode, Ty->getScalarType (), CostKind, Op1Info, Op2Info);
3484- return (4 + DivCost) * VTy->getNumElements ();
3480+ // If #vector_elements = n then we need
3481+ // n inserts + 2n extracts + n divisions.
3482+ InstructionCost InsertExtractCost =
3483+ ST->getVectorInsertExtractBaseCost ();
3484+ Cost = (3 * InsertExtractCost + DivCost) * VTy->getNumElements ();
3485+ } else if (!Scalars.empty ()) {
3486+ // If #vector_elements = n then we need
3487+ // n inserts + 2n extracts + n divisions.
3488+ InstructionCost InsertExtractCost =
3489+ ST->getVectorInsertExtractBaseCost ();
3490+ Cost = (3 * InsertExtractCost) * VTy->getNumElements ();
3491+ for (auto *V : Scalars) {
3492+ auto *I = cast<Instruction>(V);
3493+ Cost +=
3494+ getArithmeticInstrCost (I->getOpcode (), I->getType (), CostKind,
3495+ TTI::getOperandInfo (I->getOperand (0 )),
3496+ TTI::getOperandInfo (I->getOperand (1 )));
3497+ }
3498+ } else {
3499+ // FIXME: The initial cost calculated should have considered extract
3500+ // cost twice. For now, we just add additional cost to avoid
3501+ // underestimating the total cost.
3502+ Cost += Cost;
34853503 }
3504+ } else {
3505+ // We can't predict the cost of div/extract/insert without knowing the
3506+ // vector width.
3507+ Cost.setInvalid ();
34863508 }
3487- // On AArch64, without SVE, vector divisions are expanded
3488- // into scalar divisions of each pair of elements.
3489- Cost += getArithmeticInstrCost (Instruction::ExtractElement, Ty,
3490- CostKind, Op1Info, Op2Info);
3491- Cost += getArithmeticInstrCost (Instruction::InsertElement, Ty, CostKind,
3492- Op1Info, Op2Info);
34933509 }
3494-
3495- // TODO: if one of the arguments is scalar, then it's not necessary to
3496- // double the cost of handling the vector elements.
3497- Cost += Cost;
34983510 }
34993511 return Cost;
35003512 }
0 commit comments