@@ -5632,75 +5632,94 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
5632
5632
TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned > BinOp,
5633
5633
TTI::TargetCostKind CostKind) const {
5634
5634
InstructionCost Invalid = InstructionCost::getInvalid ();
5635
- InstructionCost Cost (TTI::TCC_Basic);
5636
5635
5637
5636
if (CostKind != TTI::TCK_RecipThroughput)
5638
5637
return Invalid;
5639
5638
5640
- // Sub opcodes currently only occur in chained cases.
5641
- // Independent partial reduction subtractions are still costed as an add
5639
+ if (VF.isFixed () && !ST->isSVEorStreamingSVEAvailable () &&
5640
+ (!ST->isNeonAvailable () || !ST->hasDotProd ()))
5641
+ return Invalid;
5642
+
5642
5643
if ((Opcode != Instruction::Add && Opcode != Instruction::Sub) ||
5643
5644
OpAExtend == TTI::PR_None)
5644
5645
return Invalid;
5645
5646
5647
+ assert ((BinOp || (OpBExtend == TTI::PR_None && !InputTypeB)) &&
5648
+ (!BinOp || (OpBExtend != TTI::PR_None && InputTypeB)) &&
5649
+ " Unexpected values for OpBExtend or InputTypeB" );
5650
+
5646
5651
// We only support multiply binary operations for now, and for muls we
5647
5652
// require the types being extended to be the same.
5648
- // NOTE: For muls AArch64 supports lowering mixed extensions to a usdot but
5649
- // only if the i8mm or sve/streaming features are available.
5650
- if (BinOp && (*BinOp != Instruction::Mul || InputTypeA != InputTypeB ||
5651
- OpBExtend == TTI::PR_None ||
5652
- (OpAExtend != OpBExtend && !ST->hasMatMulInt8 () &&
5653
- !ST->isSVEorStreamingSVEAvailable ())))
5653
+ if (BinOp && (*BinOp != Instruction::Mul || InputTypeA != InputTypeB))
5654
5654
return Invalid;
5655
- assert ((BinOp || (OpBExtend == TTI::PR_None && !InputTypeB)) &&
5656
- " Unexpected values for OpBExtend or InputTypeB" );
5657
5655
5658
- EVT InputEVT = EVT::getEVT (InputTypeA);
5659
- EVT AccumEVT = EVT::getEVT (AccumType);
5656
+ bool IsUSDot = OpBExtend != TTI::PR_None && OpAExtend != OpBExtend;
5657
+ if (IsUSDot && !ST->hasMatMulInt8 ())
5658
+ return Invalid;
5659
+
5660
+ unsigned Ratio =
5661
+ AccumType->getScalarSizeInBits () / InputTypeA->getScalarSizeInBits ();
5662
+ if (VF.getKnownMinValue () <= Ratio)
5663
+ return Invalid;
5664
+
5665
+ VectorType *InputVectorType = VectorType::get (InputTypeA, VF);
5666
+ VectorType *AccumVectorType =
5667
+ VectorType::get (AccumType, VF.divideCoefficientBy (Ratio));
5668
+ // We don't yet support all kinds of legalization.
5669
+ auto TA = TLI->getTypeAction (AccumVectorType->getContext (),
5670
+ EVT::getEVT (AccumVectorType));
5671
+ switch (TA) {
5672
+ default :
5673
+ return Invalid;
5674
+ case TargetLowering::TypeLegal:
5675
+ case TargetLowering::TypePromoteInteger:
5676
+ case TargetLowering::TypeSplitVector:
5677
+ break ;
5678
+ }
5679
+
5680
+ // Check what kind of type-legalisation happens.
5681
+ std::pair<InstructionCost, MVT> AccumLT =
5682
+ getTypeLegalizationCost (AccumVectorType);
5683
+ std::pair<InstructionCost, MVT> InputLT =
5684
+ getTypeLegalizationCost (InputVectorType);
5660
5685
5661
- unsigned VFMinValue = VF. getKnownMinValue () ;
5686
+ InstructionCost Cost = InputLT. first * TTI::TCC_Basic ;
5662
5687
5663
- if (VF.isScalable ()) {
5664
- if (!ST->isSVEorStreamingSVEAvailable ())
5665
- return Invalid;
5688
+ // Prefer using full types by costing half-full input types as more expensive.
5689
+ if (TypeSize::isKnownLT (InputVectorType->getPrimitiveSizeInBits (),
5690
+ TypeSize::getScalable (128 )))
5691
+ // FIXME: This can be removed after the cost of the extends are folded into
5692
+ // the dot-product expression in VPlan, after landing:
5693
+ // https://github.com/llvm/llvm-project/pull/147302
5694
+ Cost *= 2 ;
5666
5695
5667
- // Don't accept a partial reduction if the scaled accumulator is vscale x 1,
5668
- // since we can't lower that type.
5669
- unsigned Scale =
5670
- AccumEVT.getScalarSizeInBits () / InputEVT.getScalarSizeInBits ();
5671
- if (VFMinValue == Scale)
5672
- return Invalid;
5696
+ if (ST->isSVEorStreamingSVEAvailable () && !IsUSDot) {
5697
+ // i16 -> i64 is natively supported for udot/sdot
5698
+ if (AccumLT.second .getScalarType () == MVT::i64 &&
5699
+ InputLT.second .getScalarType () == MVT::i16 )
5700
+ return Cost;
5701
+ // i8 -> i64 is supported with an extra level of extends
5702
+ if (AccumLT.second .getScalarType () == MVT::i64 &&
5703
+ InputLT.second .getScalarType () == MVT::i8 )
5704
+ // FIXME: This cost should probably be a little higher, e.g. Cost + 2
5705
+ // because it requires two extra extends on the inputs. But if we'd change
5706
+ // that now, a regular reduction would be cheaper because the costs of
5707
+ // the extends in the IR are still counted. This can be fixed
5708
+ // after https://github.com/llvm/llvm-project/pull/147302 has landed.
5709
+ return Cost;
5673
5710
}
5674
- if (VF.isFixed () &&
5675
- (!ST->isNeonAvailable () || !ST->hasDotProd () || AccumEVT == MVT::i64 ))
5676
- return Invalid;
5677
5711
5678
- if (InputEVT == MVT::i8 ) {
5679
- switch (VFMinValue) {
5680
- default :
5681
- return Invalid;
5682
- case 8 :
5683
- if (AccumEVT == MVT::i32 )
5684
- Cost *= 2 ;
5685
- else if (AccumEVT != MVT::i64 )
5686
- return Invalid;
5687
- break ;
5688
- case 16 :
5689
- if (AccumEVT == MVT::i64 )
5690
- Cost *= 2 ;
5691
- else if (AccumEVT != MVT::i32 )
5692
- return Invalid;
5693
- break ;
5694
- }
5695
- } else if (InputEVT == MVT::i16 ) {
5696
- // FIXME: Allow i32 accumulator but increase cost, as we would extend
5697
- // it to i64.
5698
- if (VFMinValue != 8 || AccumEVT != MVT::i64 )
5699
- return Invalid;
5700
- } else
5701
- return Invalid;
5712
+ // i8 -> i32 is natively supported for udot/sdot/usdot, both for NEON and SVE.
5713
+ if (ST->isSVEorStreamingSVEAvailable () ||
5714
+ (AccumLT.second .isFixedLengthVector () && ST->isNeonAvailable () &&
5715
+ ST->hasDotProd ())) {
5716
+ if (AccumLT.second .getScalarType () == MVT::i32 &&
5717
+ InputLT.second .getScalarType () == MVT::i8 )
5718
+ return Cost;
5719
+ }
5702
5720
5703
- return Cost;
5721
+ // Add additional cost for the extends that would need to be inserted.
5722
+ return Cost + 4 ;
5704
5723
}
5705
5724
5706
5725
InstructionCost
0 commit comments