@@ -5632,75 +5632,88 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
56325632 TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned > BinOp,
56335633 TTI::TargetCostKind CostKind) const {
56345634 InstructionCost Invalid = InstructionCost::getInvalid ();
5635- InstructionCost Cost (TTI::TCC_Basic);
56365635
56375636 if (CostKind != TTI::TCK_RecipThroughput)
56385637 return Invalid;
56395638
5640- // Sub opcodes currently only occur in chained cases.
5641- // Independent partial reduction subtractions are still costed as an add
5639+ if (VF.isScalable () && !ST->isSVEorStreamingSVEAvailable ())
5640+ return Invalid;
5641+
5642+ if (VF.isFixed () && !ST->isSVEorStreamingSVEAvailable () &&
5643+ (!ST->isNeonAvailable () || !ST->hasDotProd ()))
5644+ return Invalid;
5645+
56425646 if ((Opcode != Instruction::Add && Opcode != Instruction::Sub) ||
56435647 OpAExtend == TTI::PR_None)
56445648 return Invalid;
56455649
56465650 // We only support multiply binary operations for now, and for muls we
56475651 // require the types being extended to be the same.
5648- // NOTE: For muls AArch64 supports lowering mixed extensions to a usdot but
5649- // only if the i8mm or sve/streaming features are available.
5650- if (BinOp && (*BinOp != Instruction::Mul || InputTypeA != InputTypeB ||
5651- OpBExtend == TTI::PR_None ||
5652- (OpAExtend != OpBExtend && !ST->hasMatMulInt8 () &&
5653- !ST->isSVEorStreamingSVEAvailable ())))
5652+ if (BinOp && (*BinOp != Instruction::Mul || InputTypeA != InputTypeB))
56545653 return Invalid;
56555654 assert ((BinOp || (OpBExtend == TTI::PR_None && !InputTypeB)) &&
56565655 " Unexpected values for OpBExtend or InputTypeB" );
56575656
5658- EVT InputEVT = EVT::getEVT (InputTypeA);
5659- EVT AccumEVT = EVT::getEVT (AccumType);
5657+ bool IsUSDot = OpBExtend && OpAExtend != OpBExtend;
5658+ if (IsUSDot && !ST->hasMatMulInt8 ())
5659+ return Invalid;
56605660
5661- unsigned VFMinValue = VF.getKnownMinValue ();
5661+ unsigned Ratio =
5662+ AccumType->getScalarSizeInBits () / InputTypeA->getScalarSizeInBits ();
5663+ if (VF.getKnownMinValue () < Ratio)
5664+ return Invalid;
56625665
5663- if (VF. isScalable ()) {
5664- if (!ST-> isSVEorStreamingSVEAvailable ())
5665- return Invalid ;
5666+ VectorType *InputVectorType = VectorType::get (InputTypeA, VF);
5667+ VectorType *AccumVectorType =
5668+ VectorType::get (AccumType, VF. divideCoefficientBy (Ratio)) ;
56665669
5667- // Don't accept a partial reduction if the scaled accumulator is vscale x 1,
5668- // since we can't lower that type.
5669- unsigned Scale =
5670- AccumEVT.getScalarSizeInBits () / InputEVT.getScalarSizeInBits ();
5671- if (VFMinValue == Scale)
5672- return Invalid;
5673- }
5674- if (VF.isFixed () &&
5675- (!ST->isNeonAvailable () || !ST->hasDotProd () || AccumEVT == MVT::i64 ))
5670+ // We don't yet support widening for <vscale x 1 x ..> accumulators.
5671+ if (AccumVectorType->getElementCount () == ElementCount::getScalable (1 ))
56765672 return Invalid;
56775673
5678- if (InputEVT == MVT::i8 ) {
5679- switch (VFMinValue) {
5680- default :
5681- return Invalid;
5682- case 8 :
5683- if (AccumEVT == MVT::i32 )
5684- Cost *= 2 ;
5685- else if (AccumEVT != MVT::i64 )
5686- return Invalid;
5687- break ;
5688- case 16 :
5689- if (AccumEVT == MVT::i64 )
5690- Cost *= 2 ;
5691- else if (AccumEVT != MVT::i32 )
5692- return Invalid;
5693- break ;
5694- }
5695- } else if (InputEVT == MVT::i16 ) {
5696- // FIXME: Allow i32 accumulator but increase cost, as we would extend
5697- // it to i64.
5698- if (VFMinValue != 8 || AccumEVT != MVT::i64 )
5699- return Invalid;
5700- } else
5701- return Invalid;
5674+ // Check what kind of type-legalisation happens.
5675+ std::pair<InstructionCost, MVT> AccumLT =
5676+ getTypeLegalizationCost (AccumVectorType);
5677+ std::pair<InstructionCost, MVT> InputLT =
5678+ getTypeLegalizationCost (InputVectorType);
57025679
5703- return Cost;
5680+ InstructionCost Cost = InputLT.first * TTI::TCC_Basic;
5681+
5682+ // Prefer using full types by costing half-full input types as more expensive.
5683+ if (TypeSize::isKnownLT (InputVectorType->getPrimitiveSizeInBits (),
5684+ TypeSize::getScalable (128 )))
5685+ // FIXME: This can be removed after the cost of the extends are folded into
5686+ // the dot-product expression in VPlan, after landing:
5687+ // https://github.com/llvm/llvm-project/pull/147302
5688+ Cost *= 2 ;
5689+
5690+ if (ST->isSVEorStreamingSVEAvailable () && !IsUSDot) {
5691+ // i16 -> i64 is natively supported for udot/sdot
5692+ if (AccumLT.second .getScalarType () == MVT::i64 &&
5693+ InputLT.second .getScalarType () == MVT::i16 )
5694+ return Cost;
5695+ // i8 -> i64 is supported with an extra level of extends
5696+ if (AccumLT.second .getScalarType () == MVT::i64 &&
5697+ InputLT.second .getScalarType () == MVT::i8 )
5698+ // FIXME: This cost should probably be a little higher, e.g. Cost + 2
5699+ // because it requires two extra extends on the inputs. But if we'd change
5700+ // that now, a regular reduction would be cheaper because the costs of
5701+ // the extends in the IR are still counted. This can be fixed
5702+ // after https://github.com/llvm/llvm-project/pull/147302 has landed.
5703+ return Cost;
5704+ }
5705+
5706+ // i8 -> i32 is natively supported for udot/sdot/usdot, both for NEON and SVE.
5707+ if (ST->isSVEorStreamingSVEAvailable () ||
5708+ (AccumLT.second .isFixedLengthVector () && ST->isNeonAvailable () &&
5709+ ST->hasDotProd ())) {
5710+ if (AccumLT.second .getScalarType () == MVT::i32 &&
5711+ InputLT.second .getScalarType () == MVT::i8 )
5712+ return Cost;
5713+ }
5714+
5715+ // Add additional cost for the extends that would need to be inserted.
5716+ return Cost + 4 ;
57045717}
57055718
57065719InstructionCost
0 commit comments