diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp index 7f10bfed739b4..f32e4d91c833e 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -4664,6 +4664,66 @@ InstructionCost AArch64TTIImpl::getSpliceCost(VectorType *Tp, int Index) { return LegalizationCost * LT.first; } +InstructionCost AArch64TTIImpl::getPartialReductionCost( + unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType, + ElementCount VF, TTI::PartialReductionExtendKind OpAExtend, + TTI::PartialReductionExtendKind OpBExtend, + std::optional BinOp) const { + InstructionCost Invalid = InstructionCost::getInvalid(); + InstructionCost Cost(TTI::TCC_Basic); + + if (Opcode != Instruction::Add) + return Invalid; + + if (InputTypeA != InputTypeB) + return Invalid; + + EVT InputEVT = EVT::getEVT(InputTypeA); + EVT AccumEVT = EVT::getEVT(AccumType); + + if (VF.isScalable() && !ST->isSVEorStreamingSVEAvailable()) + return Invalid; + if (VF.isFixed() && (!ST->isNeonAvailable() || !ST->hasDotProd())) + return Invalid; + + if (InputEVT == MVT::i8) { + switch (VF.getKnownMinValue()) { + default: + return Invalid; + case 8: + if (AccumEVT == MVT::i32) + Cost *= 2; + else if (AccumEVT != MVT::i64) + return Invalid; + break; + case 16: + if (AccumEVT == MVT::i64) + Cost *= 2; + else if (AccumEVT != MVT::i32) + return Invalid; + break; + } + } else if (InputEVT == MVT::i16) { + // FIXME: Allow i32 accumulator but increase cost, as we would extend + // it to i64. + if (VF.getKnownMinValue() != 8 || AccumEVT != MVT::i64) + return Invalid; + } else + return Invalid; + + // AArch64 supports lowering mixed extensions to a usdot but only if the + // i8mm or sve/streaming features are available. + if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None || + (OpAExtend != OpBExtend && !ST->hasMatMulInt8() && + !ST->isSVEorStreamingSVEAvailable())) + return Invalid; + + if (!BinOp || *BinOp != Instruction::Mul) + return Invalid; + + return Cost; +} + InstructionCost AArch64TTIImpl::getShuffleCost( TTI::ShuffleKind Kind, VectorType *Tp, ArrayRef Mask, TTI::TargetCostKind CostKind, int Index, VectorType *SubTp, diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h index 1eb805ae00b1b..b65e3c7a1ab20 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h @@ -367,62 +367,7 @@ class AArch64TTIImpl : public BasicTTIImplBase { Type *AccumType, ElementCount VF, TTI::PartialReductionExtendKind OpAExtend, TTI::PartialReductionExtendKind OpBExtend, - std::optional BinOp) const { - - InstructionCost Invalid = InstructionCost::getInvalid(); - InstructionCost Cost(TTI::TCC_Basic); - - if (Opcode != Instruction::Add) - return Invalid; - - if (InputTypeA != InputTypeB) - return Invalid; - - EVT InputEVT = EVT::getEVT(InputTypeA); - EVT AccumEVT = EVT::getEVT(AccumType); - - if (VF.isScalable() && !ST->isSVEorStreamingSVEAvailable()) - return Invalid; - if (VF.isFixed() && (!ST->isNeonAvailable() || !ST->hasDotProd())) - return Invalid; - - if (InputEVT == MVT::i8) { - switch (VF.getKnownMinValue()) { - default: - return Invalid; - case 8: - if (AccumEVT == MVT::i32) - Cost *= 2; - else if (AccumEVT != MVT::i64) - return Invalid; - break; - case 16: - if (AccumEVT == MVT::i64) - Cost *= 2; - else if (AccumEVT != MVT::i32) - return Invalid; - break; - } - } else if (InputEVT == MVT::i16) { - // FIXME: Allow i32 accumulator but increase cost, as we would extend - // it to i64. - if (VF.getKnownMinValue() != 8 || AccumEVT != MVT::i64) - return Invalid; - } else - return Invalid; - - // AArch64 supports lowering mixed extensions to a usdot but only if the - // i8mm or sve/streaming features are available. - if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None || - (OpAExtend != OpBExtend && !ST->hasMatMulInt8() && - !ST->isSVEorStreamingSVEAvailable())) - return Invalid; - - if (!BinOp || *BinOp != Instruction::Mul) - return Invalid; - - return Cost; - } + std::optional BinOp) const; bool enableOrderedReductions() const { return true; }