Commit b58b094

[AArch64] Refactor and refine cost-model for partial reductions
This cost-model takes into account any type-legalisation that would happen on vectors, such as splitting and promotion. This results in wider VFs being chosen for loops that can use partial reductions. The cost-model now also assumes that when SVE is available, the SVE dot instructions for i16 -> i64 dot products can be used for fixed-length vectors.

In practice this means that loops with non-scalable VFs are vectorized using partial reductions where they wouldn't be before, e.g.

  int64_t foo2(int8_t *src1, int8_t *src2, int N) {
    int64_t sum = 0;
    for (int i = 0; i < N; ++i)
      sum += (int64_t)src1[i] * (int64_t)src2[i];
    return sum;
  }

These changes also fix an issue where previously a partial reduction would be used for mixed sign/zero-extends (USDOT), even when +i8mm was not available.
1 parent c4e1bca commit b58b094
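To make the transformation concrete, here is a minimal hand-written sketch (not taken from this commit; the function and value names are illustrative) of the partial-reduction form that one vector iteration of foo2 can take at a fixed VF of 16. Only the llvm.experimental.vector.partial.reduce.add intrinsic is assumed from LLVM: the 16 sign-extended i8 products are accumulated into a 2-lane i64 vector, which AArch64 can lower with dot-product instructions instead of a full horizontal reduction.

  ; Sketch only: one vector iteration of foo2 at VF=16, i8 inputs, i64 accumulator.
  define <2 x i64> @foo2_partial_reduce_sketch(<2 x i64> %acc, <16 x i8> %a, <16 x i8> %b) {
    %a.ext = sext <16 x i8> %a to <16 x i64>
    %b.ext = sext <16 x i8> %b to <16 x i64>
    %mul = mul nsw <16 x i64> %a.ext, %b.ext
    ; Fold the 16 products into the 2-lane accumulator (extend ratio 64/8 = 8).
    %red = call <2 x i64> @llvm.experimental.vector.partial.reduce.add.v2i64.v16i64(<2 x i64> %acc, <16 x i64> %mul)
    ret <2 x i64> %red
  }

  declare <2 x i64> @llvm.experimental.vector.partial.reduce.add.v2i64.v16i64(<2 x i64>, <16 x i64>)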

10 files changed: +548, -671 lines changed

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 62 additions & 49 deletions
@@ -5632,75 +5632,88 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
     TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
     TTI::TargetCostKind CostKind) const {
   InstructionCost Invalid = InstructionCost::getInvalid();
-  InstructionCost Cost(TTI::TCC_Basic);
 
   if (CostKind != TTI::TCK_RecipThroughput)
     return Invalid;
 
-  // Sub opcodes currently only occur in chained cases.
-  // Independent partial reduction subtractions are still costed as an add
+  if (VF.isScalable() && !ST->isSVEorStreamingSVEAvailable())
+    return Invalid;
+
+  if (VF.isFixed() && !ST->isSVEorStreamingSVEAvailable() &&
+      (!ST->isNeonAvailable() || !ST->hasDotProd()))
+    return Invalid;
+
   if ((Opcode != Instruction::Add && Opcode != Instruction::Sub) ||
       OpAExtend == TTI::PR_None)
     return Invalid;
 
   // We only support multiply binary operations for now, and for muls we
   // require the types being extended to be the same.
-  // NOTE: For muls AArch64 supports lowering mixed extensions to a usdot but
-  // only if the i8mm or sve/streaming features are available.
-  if (BinOp && (*BinOp != Instruction::Mul || InputTypeA != InputTypeB ||
-                OpBExtend == TTI::PR_None ||
-                (OpAExtend != OpBExtend && !ST->hasMatMulInt8() &&
-                 !ST->isSVEorStreamingSVEAvailable())))
+  if (BinOp && (*BinOp != Instruction::Mul || InputTypeA != InputTypeB))
     return Invalid;
   assert((BinOp || (OpBExtend == TTI::PR_None && !InputTypeB)) &&
          "Unexpected values for OpBExtend or InputTypeB");
 
-  EVT InputEVT = EVT::getEVT(InputTypeA);
-  EVT AccumEVT = EVT::getEVT(AccumType);
+  bool IsUSDot = OpBExtend && OpAExtend != OpBExtend;
+  if (IsUSDot && !ST->hasMatMulInt8())
+    return Invalid;
 
-  unsigned VFMinValue = VF.getKnownMinValue();
+  unsigned Ratio =
+      AccumType->getScalarSizeInBits() / InputTypeA->getScalarSizeInBits();
+  if (VF.getKnownMinValue() < Ratio)
+    return Invalid;
 
-  if (VF.isScalable()) {
-    if (!ST->isSVEorStreamingSVEAvailable())
-      return Invalid;
+  VectorType *InputVectorType = VectorType::get(InputTypeA, VF);
+  VectorType *AccumVectorType =
+      VectorType::get(AccumType, VF.divideCoefficientBy(Ratio));
 
-    // Don't accept a partial reduction if the scaled accumulator is vscale x 1,
-    // since we can't lower that type.
-    unsigned Scale =
-        AccumEVT.getScalarSizeInBits() / InputEVT.getScalarSizeInBits();
-    if (VFMinValue == Scale)
-      return Invalid;
-  }
-  if (VF.isFixed() &&
-      (!ST->isNeonAvailable() || !ST->hasDotProd() || AccumEVT == MVT::i64))
+  // We don't yet support widening for <vscale x 1 x ..> accumulators.
+  if (AccumVectorType->getElementCount() == ElementCount::getScalable(1))
     return Invalid;
 
-  if (InputEVT == MVT::i8) {
-    switch (VFMinValue) {
-    default:
-      return Invalid;
-    case 8:
-      if (AccumEVT == MVT::i32)
-        Cost *= 2;
-      else if (AccumEVT != MVT::i64)
-        return Invalid;
-      break;
-    case 16:
-      if (AccumEVT == MVT::i64)
-        Cost *= 2;
-      else if (AccumEVT != MVT::i32)
-        return Invalid;
-      break;
-    }
-  } else if (InputEVT == MVT::i16) {
-    // FIXME: Allow i32 accumulator but increase cost, as we would extend
-    // it to i64.
-    if (VFMinValue != 8 || AccumEVT != MVT::i64)
-      return Invalid;
-  } else
-    return Invalid;
+  // Check what kind of type-legalisation happens.
+  std::pair<InstructionCost, MVT> AccumLT =
+      getTypeLegalizationCost(AccumVectorType);
+  std::pair<InstructionCost, MVT> InputLT =
+      getTypeLegalizationCost(InputVectorType);
 
-  return Cost;
+  InstructionCost Cost = InputLT.first * TTI::TCC_Basic;
+
+  // Prefer using full types by costing half-full input types as more expensive.
+  if (TypeSize::isKnownLT(InputVectorType->getPrimitiveSizeInBits(),
+                          TypeSize::getScalable(128)))
+    // FIXME: This can be removed after the cost of the extends are folded into
+    // the dot-product expression in VPlan, after landing:
+    // https://github.com/llvm/llvm-project/pull/147302
+    Cost *= 2;
+
+  if (ST->isSVEorStreamingSVEAvailable() && !IsUSDot) {
+    // i16 -> i64 is natively supported for udot/sdot
+    if (AccumLT.second.getScalarType() == MVT::i64 &&
+        InputLT.second.getScalarType() == MVT::i16)
+      return Cost;
+    // i8 -> i64 is supported with an extra level of extends
+    if (AccumLT.second.getScalarType() == MVT::i64 &&
+        InputLT.second.getScalarType() == MVT::i8)
+      // FIXME: This cost should probably be a little higher, e.g. Cost + 2
+      // because it requires two extra extends on the inputs. But if we'd change
+      // that now, a regular reduction would be cheaper because the costs of
+      // the extends in the IR are still counted. This can be fixed
+      // after https://github.com/llvm/llvm-project/pull/147302 has landed.
+      return Cost;
+  }
+
+  // i8 -> i32 is natively supported for udot/sdot/usdot, both for NEON and SVE.
+  if (ST->isSVEorStreamingSVEAvailable() ||
+      (AccumLT.second.isFixedLengthVector() && ST->isNeonAvailable() &&
+       ST->hasDotProd())) {
+    if (AccumLT.second.getScalarType() == MVT::i32 &&
+        InputLT.second.getScalarType() == MVT::i8)
+      return Cost;
+  }
+
+  // Add additional cost for the extends that would need to be inserted.
+  return Cost + 4;
 }
 
 InstructionCost
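For the i16 -> i64 path that the SVE-specific branch above now costs as a native udot/sdot, here is a similarly hedged sketch (illustrative only, not part of the commit) of the fixed-length form the commit message refers to: the extend ratio is 64/16 = 4, so a VF of 8 feeds a 2-lane i64 accumulator.

  ; Sketch only: fixed-length i16 -> i64 partial reduction at VF=8.
  define <2 x i64> @i16_to_i64_partial_reduce_sketch(<2 x i64> %acc, <8 x i16> %a, <8 x i16> %b) {
    %a.ext = sext <8 x i16> %a to <8 x i64>
    %b.ext = sext <8 x i16> %b to <8 x i64>
    %mul = mul nsw <8 x i64> %a.ext, %b.ext
    ; Fold the 8 products into the 2-lane accumulator (extend ratio 64/16 = 4).
    %red = call <2 x i64> @llvm.experimental.vector.partial.reduce.add.v2i64.v8i64(<2 x i64> %acc, <8 x i64> %mul)
    ret <2 x i64> %red
  }

  declare <2 x i64> @llvm.experimental.vector.partial.reduce.add.v2i64.v8i64(<2 x i64>, <8 x i64>)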

llvm/test/Transforms/LoopVectorize/AArch64/fully-unrolled-cost.ll

Lines changed: 2 additions & 2 deletions
@@ -82,11 +82,11 @@ define i64 @test_two_ivs(ptr %a, ptr %b, i64 %start) #0 {
 ; CHECK-NEXT: Cost of 0 for VF 8: induction instruction %j.iv = phi i64 [ %start, %entry ], [ %j.iv.next, %for.body ]
 ; CHECK-NEXT: Cost of 1 for VF 8: exit condition instruction %exitcond.not = icmp eq i64 %i.iv.next, 16
 ; CHECK-NEXT: Cost of 0 for VF 8: EMIT vp<{{.+}}> = CANONICAL-INDUCTION ir<0>, vp<%index.next>
-; CHECK: Cost for VF 8: 27
+; CHECK: Cost for VF 8: 25
 ; CHECK-NEXT: Cost of 0 for VF 16: induction instruction %i.iv = phi i64 [ 0, %entry ], [ %i.iv.next, %for.body ]
 ; CHECK-NEXT: Cost of 0 for VF 16: induction instruction %j.iv = phi i64 [ %start, %entry ], [ %j.iv.next, %for.body ]
 ; CHECK-NEXT: Cost of 0 for VF 16: EMIT vp<{{.+}}> = CANONICAL-INDUCTION ir<0>, vp<%index.next>
-; CHECK: Cost for VF 16: 48
+; CHECK: Cost for VF 16: 41
 ; CHECK: LV: Selecting VF: 16
 entry:
   br label %for.body
