@@ -4664,6 +4664,66 @@ InstructionCost AArch64TTIImpl::getSpliceCost(VectorType *Tp, int Index) {
46644664 return LegalizationCost * LT.first ;
46654665}
46664666
4667+ InstructionCost AArch64TTIImpl::getPartialReductionCost (
4668+ unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
4669+ ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
4670+ TTI::PartialReductionExtendKind OpBExtend,
4671+ std::optional<unsigned > BinOp) const {
4672+ InstructionCost Invalid = InstructionCost::getInvalid ();
4673+ InstructionCost Cost (TTI::TCC_Basic);
4674+
4675+ if (Opcode != Instruction::Add)
4676+ return Invalid;
4677+
4678+ if (InputTypeA != InputTypeB)
4679+ return Invalid;
4680+
4681+ EVT InputEVT = EVT::getEVT (InputTypeA);
4682+ EVT AccumEVT = EVT::getEVT (AccumType);
4683+
4684+ if (VF.isScalable () && !ST->isSVEorStreamingSVEAvailable ())
4685+ return Invalid;
4686+ if (VF.isFixed () && (!ST->isNeonAvailable () || !ST->hasDotProd ()))
4687+ return Invalid;
4688+
4689+ if (InputEVT == MVT::i8 ) {
4690+ switch (VF.getKnownMinValue ()) {
4691+ default :
4692+ return Invalid;
4693+ case 8 :
4694+ if (AccumEVT == MVT::i32 )
4695+ Cost *= 2 ;
4696+ else if (AccumEVT != MVT::i64 )
4697+ return Invalid;
4698+ break ;
4699+ case 16 :
4700+ if (AccumEVT == MVT::i64 )
4701+ Cost *= 2 ;
4702+ else if (AccumEVT != MVT::i32 )
4703+ return Invalid;
4704+ break ;
4705+ }
4706+ } else if (InputEVT == MVT::i16 ) {
4707+ // FIXME: Allow i32 accumulator but increase cost, as we would extend
4708+ // it to i64.
4709+ if (VF.getKnownMinValue () != 8 || AccumEVT != MVT::i64 )
4710+ return Invalid;
4711+ } else
4712+ return Invalid;
4713+
4714+ // AArch64 supports lowering mixed extensions to a usdot but only if the
4715+ // i8mm or sve/streaming features are available.
4716+ if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None ||
4717+ (OpAExtend != OpBExtend && !ST->hasMatMulInt8 () &&
4718+ !ST->isSVEorStreamingSVEAvailable ()))
4719+ return Invalid;
4720+
4721+ if (!BinOp || *BinOp != Instruction::Mul)
4722+ return Invalid;
4723+
4724+ return Cost;
4725+ }
4726+
46674727InstructionCost AArch64TTIImpl::getShuffleCost (
46684728 TTI::ShuffleKind Kind, VectorType *Tp, ArrayRef<int > Mask,
46694729 TTI::TargetCostKind CostKind, int Index, VectorType *SubTp,
@@ -5573,64 +5633,3 @@ bool AArch64TTIImpl::isProfitableToSinkOperands(
55735633 }
55745634 return false ;
55755635}
5576-
// Removed side of the hunk: the previous definition of
// AArch64TTIImpl::getPartialReductionCost at the end of the file. The same
// body is re-added earlier in this diff, directly after getSpliceCost; only
// the wrapping of the parameter list differs between the two copies.
//
// The function returns TTI::TCC_Basic (doubled in two i8 cases below) for a
// supported partial reduction and InstructionCost::getInvalid() otherwise.
5577- InstructionCost
5578- AArch64TTIImpl::getPartialReductionCost (unsigned Opcode, Type *InputTypeA, Type *InputTypeB,
5579- Type *AccumType, ElementCount VF,
5580- TTI::PartialReductionExtendKind OpAExtend,
5581- TTI::PartialReductionExtendKind OpBExtend,
5582- std::optional<unsigned > BinOp) const {
5583- InstructionCost Invalid = InstructionCost::getInvalid ();
5584- InstructionCost Cost (TTI::TCC_Basic);
5585-
// Only add reductions are costed.
5586- if (Opcode != Instruction::Add)
5587- return Invalid;
5588-
// Both inputs must have the same type.
5589- if (InputTypeA != InputTypeB)
5590- return Invalid;
5591-
5592- EVT InputEVT = EVT::getEVT (InputTypeA);
5593- EVT AccumEVT = EVT::getEVT (AccumType);
5594-
// Subtarget gate: scalable VFs require SVE or streaming SVE; fixed VFs
// require NEON plus the dot-product feature.
5595- if (VF.isScalable () && !ST->isSVEorStreamingSVEAvailable ())
5596- return Invalid;
5597- if (VF.isFixed () && (!ST->isNeonAvailable () || !ST->hasDotProd ()))
5598- return Invalid;
5599-
// i8 inputs: VF 8 or 16 with an i32 or i64 accumulator. The cost is doubled
// for (VF 8, i32) and for (VF 16, i64).
5600- if (InputEVT == MVT::i8 ) {
5601- switch (VF.getKnownMinValue ()) {
5602- default :
5603- return Invalid;
5604- case 8 :
5605- if (AccumEVT == MVT::i32 )
5606- Cost *= 2 ;
5607- else if (AccumEVT != MVT::i64 )
5608- return Invalid;
5609- break ;
5610- case 16 :
5611- if (AccumEVT == MVT::i64 )
5612- Cost *= 2 ;
5613- else if (AccumEVT != MVT::i32 )
5614- return Invalid;
5615- break ;
5616- }
// i16 inputs: only VF 8 with an i64 accumulator is accepted.
5617- } else if (InputEVT == MVT::i16 ) {
5618- // FIXME: Allow i32 accumulator but increase cost, as we would extend
5619- // it to i64.
5620- if (VF.getKnownMinValue () != 8 || AccumEVT != MVT::i64 )
5621- return Invalid;
5622- } else
5623- return Invalid;
5624-
// Both operands must be extended (neither may be PR_None); differing extend
// kinds are additionally gated on hasMatMulInt8() or SVE/streaming SVE.
5625- // AArch64 supports lowering mixed extensions to a usdot but only if the
5626- // i8mm or sve/streaming features are available.
5627- if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None ||
5628- (OpAExtend != OpBExtend && !ST->hasMatMulInt8 () &&
5629- !ST->isSVEorStreamingSVEAvailable ()))
5630- return Invalid;
5631-
// A binary operation must be supplied and it must be a multiply.
5632- if (!BinOp || *BinOp != Instruction::Mul)
5633- return Invalid;
5634-
5635- return Cost;
5636- }
0 commit comments