@@ -4635,6 +4635,54 @@ AArch64TTIImpl::getArithmeticReductionCost(unsigned Opcode, VectorType *ValTy,
46354635 return BaseT::getArithmeticReductionCost (Opcode, ValTy, FMF, CostKind);
46364636}
46374637
4638+ InstructionCost AArch64TTIImpl::getExtendedReductionCost (
4639+ unsigned Opcode, bool IsUnsigned, Type *ResTy, VectorType *VecTy,
4640+ FastMathFlags FMF, TTI::TargetCostKind CostKind) {
4641+ EVT VecVT = TLI->getValueType (DL, VecTy);
4642+ EVT ResVT = TLI->getValueType (DL, ResTy);
4643+
4644+ if (Opcode == Instruction::Add && VecVT.isSimple () && ResVT.isSimple () &&
4645+ VecVT.getSizeInBits () >= 64 ) {
4646+ std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost (VecTy);
4647+
4648+ // The legal cases are:
4649+ // UADDLV 8/16/32->32
4650+ // UADDLP 32->64
4651+ unsigned RevVTSize = ResVT.getSizeInBits ();
4652+ if (((LT.second == MVT::v8i8 || LT.second == MVT::v16i8) &&
4653+ RevVTSize <= 32 ) ||
4654+ ((LT.second == MVT::v4i16 || LT.second == MVT::v8i16) &&
4655+ RevVTSize <= 32 ) ||
4656+ ((LT.second == MVT::v2i32 || LT.second == MVT::v4i32) &&
4657+ RevVTSize <= 64 ))
4658+ return (LT.first - 1 ) * 2 + 2 ;
4659+ }
4660+
4661+ return BaseT::getExtendedReductionCost (Opcode, IsUnsigned, ResTy, VecTy, FMF,
4662+ CostKind);
4663+ }
4664+
4665+ InstructionCost
4666+ AArch64TTIImpl::getMulAccReductionCost (bool IsUnsigned, Type *ResTy,
4667+ VectorType *VecTy,
4668+ TTI::TargetCostKind CostKind) {
4669+ EVT VecVT = TLI->getValueType (DL, VecTy);
4670+ EVT ResVT = TLI->getValueType (DL, ResTy);
4671+
4672+ if (ST->hasDotProd () && VecVT.isSimple () && ResVT.isSimple ()) {
4673+ std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost (VecTy);
4674+
4675+ // The legal cases with dotprod are
4676+ // UDOT 8->32
4677+ // Which requires an additional uaddv to sum the i32 values.
4678+ if ((LT.second == MVT::v8i8 || LT.second == MVT::v16i8) &&
4679+ ResVT == MVT::i32 )
4680+ return LT.first + 2 ;
4681+ }
4682+
4683+ return BaseT::getMulAccReductionCost (IsUnsigned, ResTy, VecTy, CostKind);
4684+ }
4685+
46384686InstructionCost AArch64TTIImpl::getSpliceCost (VectorType *Tp, int Index) {
46394687 static const CostTblEntry ShuffleTbl[] = {
46404688 { TTI::SK_Splice, MVT::nxv16i8, 1 },
0 commit comments