@@ -5632,75 +5632,94 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
56325632 TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned > BinOp,
56335633 TTI::TargetCostKind CostKind) const {
56345634 InstructionCost Invalid = InstructionCost::getInvalid ();
5635- InstructionCost Cost (TTI::TCC_Basic);
56365635
56375636 if (CostKind != TTI::TCK_RecipThroughput)
56385637 return Invalid;
56395638
5640- // Sub opcodes currently only occur in chained cases.
5641- // Independent partial reduction subtractions are still costed as an add
5639+ if (VF.isFixed () && !ST->isSVEorStreamingSVEAvailable () &&
5640+ (!ST->isNeonAvailable () || !ST->hasDotProd ()))
5641+ return Invalid;
5642+
56425643 if ((Opcode != Instruction::Add && Opcode != Instruction::Sub) ||
56435644 OpAExtend == TTI::PR_None)
56445645 return Invalid;
56455646
5647+ assert ((BinOp || (OpBExtend == TTI::PR_None && !InputTypeB)) &&
5648+ (!BinOp || (OpBExtend != TTI::PR_None && InputTypeB)) &&
5649+ " Unexpected values for OpBExtend or InputTypeB" );
5650+
56465651 // We only support multiply binary operations for now, and for muls we
56475652 // require the types being extended to be the same.
5648- // NOTE: For muls AArch64 supports lowering mixed extensions to a usdot but
5649- // only if the i8mm or sve/streaming features are available.
5650- if (BinOp && (*BinOp != Instruction::Mul || InputTypeA != InputTypeB ||
5651- OpBExtend == TTI::PR_None ||
5652- (OpAExtend != OpBExtend && !ST->hasMatMulInt8 () &&
5653- !ST->isSVEorStreamingSVEAvailable ())))
5653+ if (BinOp && (*BinOp != Instruction::Mul || InputTypeA != InputTypeB))
56545654 return Invalid;
5655- assert ((BinOp || (OpBExtend == TTI::PR_None && !InputTypeB)) &&
5656- " Unexpected values for OpBExtend or InputTypeB" );
56575655
5658- EVT InputEVT = EVT::getEVT (InputTypeA);
5659- EVT AccumEVT = EVT::getEVT (AccumType);
5656+ bool IsUSDot = OpBExtend != TTI::PR_None && OpAExtend != OpBExtend;
5657+ if (IsUSDot && !ST->hasMatMulInt8 ())
5658+ return Invalid;
5659+
5660+ unsigned Ratio =
5661+ AccumType->getScalarSizeInBits () / InputTypeA->getScalarSizeInBits ();
5662+ if (VF.getKnownMinValue () <= Ratio)
5663+ return Invalid;
5664+
5665+ VectorType *InputVectorType = VectorType::get (InputTypeA, VF);
5666+ VectorType *AccumVectorType =
5667+ VectorType::get (AccumType, VF.divideCoefficientBy (Ratio));
5668+ // We don't yet support all kinds of legalization.
5669+ auto TA = TLI->getTypeAction (AccumVectorType->getContext (),
5670+ EVT::getEVT (AccumVectorType));
5671+ switch (TA) {
5672+ default :
5673+ return Invalid;
5674+ case TargetLowering::TypeLegal:
5675+ case TargetLowering::TypePromoteInteger:
5676+ case TargetLowering::TypeSplitVector:
5677+ break ;
5678+ }
5679+
5680+ // Check what kind of type-legalisation happens.
5681+ std::pair<InstructionCost, MVT> AccumLT =
5682+ getTypeLegalizationCost (AccumVectorType);
5683+ std::pair<InstructionCost, MVT> InputLT =
5684+ getTypeLegalizationCost (InputVectorType);
56605685
5661- unsigned VFMinValue = VF. getKnownMinValue () ;
5686+ InstructionCost Cost = InputLT. first * TTI::TCC_Basic ;
56625687
5663- if (VF.isScalable ()) {
5664- if (!ST->isSVEorStreamingSVEAvailable ())
5665- return Invalid;
5688+ // Prefer using full types by costing half-full input types as more expensive.
5689+ if (TypeSize::isKnownLT (InputVectorType->getPrimitiveSizeInBits (),
5690+ TypeSize::getScalable (128 )))
5691+ // FIXME: This can be removed after the cost of the extends are folded into
5692+ // the dot-product expression in VPlan, after landing:
5693+ // https://github.com/llvm/llvm-project/pull/147302
5694+ Cost *= 2 ;
56665695
5667- // Don't accept a partial reduction if the scaled accumulator is vscale x 1,
5668- // since we can't lower that type.
5669- unsigned Scale =
5670- AccumEVT.getScalarSizeInBits () / InputEVT.getScalarSizeInBits ();
5671- if (VFMinValue == Scale)
5672- return Invalid;
5696+ if (ST->isSVEorStreamingSVEAvailable () && !IsUSDot) {
5697+ // i16 -> i64 is natively supported for udot/sdot
5698+ if (AccumLT.second .getScalarType () == MVT::i64 &&
5699+ InputLT.second .getScalarType () == MVT::i16 )
5700+ return Cost;
5701+ // i8 -> i64 is supported with an extra level of extends
5702+ if (AccumLT.second .getScalarType () == MVT::i64 &&
5703+ InputLT.second .getScalarType () == MVT::i8 )
5704+ // FIXME: This cost should probably be a little higher, e.g. Cost + 2
5705+ // because it requires two extra extends on the inputs. But if we'd change
5706+ // that now, a regular reduction would be cheaper because the costs of
5707+ // the extends in the IR are still counted. This can be fixed
5708+ // after https://github.com/llvm/llvm-project/pull/147302 has landed.
5709+ return Cost;
56735710 }
5674- if (VF.isFixed () &&
5675- (!ST->isNeonAvailable () || !ST->hasDotProd () || AccumEVT == MVT::i64 ))
5676- return Invalid;
56775711
5678- if (InputEVT == MVT::i8 ) {
5679- switch (VFMinValue) {
5680- default :
5681- return Invalid;
5682- case 8 :
5683- if (AccumEVT == MVT::i32 )
5684- Cost *= 2 ;
5685- else if (AccumEVT != MVT::i64 )
5686- return Invalid;
5687- break ;
5688- case 16 :
5689- if (AccumEVT == MVT::i64 )
5690- Cost *= 2 ;
5691- else if (AccumEVT != MVT::i32 )
5692- return Invalid;
5693- break ;
5694- }
5695- } else if (InputEVT == MVT::i16 ) {
5696- // FIXME: Allow i32 accumulator but increase cost, as we would extend
5697- // it to i64.
5698- if (VFMinValue != 8 || AccumEVT != MVT::i64 )
5699- return Invalid;
5700- } else
5701- return Invalid;
5712+ // i8 -> i32 is natively supported for udot/sdot/usdot, both for NEON and SVE.
5713+ if (ST->isSVEorStreamingSVEAvailable () ||
5714+ (AccumLT.second .isFixedLengthVector () && ST->isNeonAvailable () &&
5715+ ST->hasDotProd ())) {
5716+ if (AccumLT.second .getScalarType () == MVT::i32 &&
5717+ InputLT.second .getScalarType () == MVT::i8 )
5718+ return Cost;
5719+ }
57025720
5703- return Cost;
5721+ // Add additional cost for the extends that would need to be inserted.
5722+ return Cost + 4 ;
57045723}
57055724
57065725InstructionCost
0 commit comments