Skip to content

Commit bbb25c7

Browse files
sdesmalen-arm authored and github-actions[bot] committed
Automerge: [AArch64] Refactor and refine cost-model for partial reductions (#158641)
This cost-model takes into account any type-legalisation that would happen on vectors, such as splitting and promotion. This results in wider VFs being chosen for loops that can use partial reductions.

The cost-model now also assumes that when SVE is available, the SVE dot instructions for i16 -> i64 dot products can be used for fixed-length vectors. In practice this means that loops with non-scalable VFs are vectorized using partial reductions where they wouldn't before, e.g.

```
int64_t foo2(int8_t *src1, int8_t *src2, int N) {
  int64_t sum = 0;
  for (int i=0; i<N; ++i)
    sum += (int64_t)src1[i] * (int64_t)src2[i];
  return sum;
}
```

These changes also fix an issue where previously a partial reduction would be used for mixed sign/zero-extends (USDOT), even when +i8mm was not available.
2 parents b790558 + cc9c64d commit bbb25c7

9 files changed

+549
-666
lines changed

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 70 additions & 51 deletions
Original file line number | Diff line number | Diff line change
@@ -5632,75 +5632,94 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
56325632
TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
56335633
TTI::TargetCostKind CostKind) const {
56345634
InstructionCost Invalid = InstructionCost::getInvalid();
5635-
InstructionCost Cost(TTI::TCC_Basic);
56365635

56375636
if (CostKind != TTI::TCK_RecipThroughput)
56385637
return Invalid;
56395638

5640-
// Sub opcodes currently only occur in chained cases.
5641-
// Independent partial reduction subtractions are still costed as an add
5639+
if (VF.isFixed() && !ST->isSVEorStreamingSVEAvailable() &&
5640+
(!ST->isNeonAvailable() || !ST->hasDotProd()))
5641+
return Invalid;
5642+
56425643
if ((Opcode != Instruction::Add && Opcode != Instruction::Sub) ||
56435644
OpAExtend == TTI::PR_None)
56445645
return Invalid;
56455646

5647+
assert((BinOp || (OpBExtend == TTI::PR_None && !InputTypeB)) &&
5648+
(!BinOp || (OpBExtend != TTI::PR_None && InputTypeB)) &&
5649+
"Unexpected values for OpBExtend or InputTypeB");
5650+
56465651
// We only support multiply binary operations for now, and for muls we
56475652
// require the types being extended to be the same.
5648-
// NOTE: For muls AArch64 supports lowering mixed extensions to a usdot but
5649-
// only if the i8mm or sve/streaming features are available.
5650-
if (BinOp && (*BinOp != Instruction::Mul || InputTypeA != InputTypeB ||
5651-
OpBExtend == TTI::PR_None ||
5652-
(OpAExtend != OpBExtend && !ST->hasMatMulInt8() &&
5653-
!ST->isSVEorStreamingSVEAvailable())))
5653+
if (BinOp && (*BinOp != Instruction::Mul || InputTypeA != InputTypeB))
56545654
return Invalid;
5655-
assert((BinOp || (OpBExtend == TTI::PR_None && !InputTypeB)) &&
5656-
"Unexpected values for OpBExtend or InputTypeB");
56575655

5658-
EVT InputEVT = EVT::getEVT(InputTypeA);
5659-
EVT AccumEVT = EVT::getEVT(AccumType);
5656+
bool IsUSDot = OpBExtend != TTI::PR_None && OpAExtend != OpBExtend;
5657+
if (IsUSDot && !ST->hasMatMulInt8())
5658+
return Invalid;
5659+
5660+
unsigned Ratio =
5661+
AccumType->getScalarSizeInBits() / InputTypeA->getScalarSizeInBits();
5662+
if (VF.getKnownMinValue() <= Ratio)
5663+
return Invalid;
5664+
5665+
VectorType *InputVectorType = VectorType::get(InputTypeA, VF);
5666+
VectorType *AccumVectorType =
5667+
VectorType::get(AccumType, VF.divideCoefficientBy(Ratio));
5668+
// We don't yet support all kinds of legalization.
5669+
auto TA = TLI->getTypeAction(AccumVectorType->getContext(),
5670+
EVT::getEVT(AccumVectorType));
5671+
switch (TA) {
5672+
default:
5673+
return Invalid;
5674+
case TargetLowering::TypeLegal:
5675+
case TargetLowering::TypePromoteInteger:
5676+
case TargetLowering::TypeSplitVector:
5677+
break;
5678+
}
5679+
5680+
// Check what kind of type-legalisation happens.
5681+
std::pair<InstructionCost, MVT> AccumLT =
5682+
getTypeLegalizationCost(AccumVectorType);
5683+
std::pair<InstructionCost, MVT> InputLT =
5684+
getTypeLegalizationCost(InputVectorType);
56605685

5661-
unsigned VFMinValue = VF.getKnownMinValue();
5686+
InstructionCost Cost = InputLT.first * TTI::TCC_Basic;
56625687

5663-
if (VF.isScalable()) {
5664-
if (!ST->isSVEorStreamingSVEAvailable())
5665-
return Invalid;
5688+
// Prefer using full types by costing half-full input types as more expensive.
5689+
if (TypeSize::isKnownLT(InputVectorType->getPrimitiveSizeInBits(),
5690+
TypeSize::getScalable(128)))
5691+
// FIXME: This can be removed after the cost of the extends are folded into
5692+
// the dot-product expression in VPlan, after landing:
5693+
// https://github.com/llvm/llvm-project/pull/147302
5694+
Cost *= 2;
56665695

5667-
// Don't accept a partial reduction if the scaled accumulator is vscale x 1,
5668-
// since we can't lower that type.
5669-
unsigned Scale =
5670-
AccumEVT.getScalarSizeInBits() / InputEVT.getScalarSizeInBits();
5671-
if (VFMinValue == Scale)
5672-
return Invalid;
5696+
if (ST->isSVEorStreamingSVEAvailable() && !IsUSDot) {
5697+
// i16 -> i64 is natively supported for udot/sdot
5698+
if (AccumLT.second.getScalarType() == MVT::i64 &&
5699+
InputLT.second.getScalarType() == MVT::i16)
5700+
return Cost;
5701+
// i8 -> i64 is supported with an extra level of extends
5702+
if (AccumLT.second.getScalarType() == MVT::i64 &&
5703+
InputLT.second.getScalarType() == MVT::i8)
5704+
// FIXME: This cost should probably be a little higher, e.g. Cost + 2
5705+
// because it requires two extra extends on the inputs. But if we'd change
5706+
// that now, a regular reduction would be cheaper because the costs of
5707+
// the extends in the IR are still counted. This can be fixed
5708+
// after https://github.com/llvm/llvm-project/pull/147302 has landed.
5709+
return Cost;
56735710
}
5674-
if (VF.isFixed() &&
5675-
(!ST->isNeonAvailable() || !ST->hasDotProd() || AccumEVT == MVT::i64))
5676-
return Invalid;
56775711

5678-
if (InputEVT == MVT::i8) {
5679-
switch (VFMinValue) {
5680-
default:
5681-
return Invalid;
5682-
case 8:
5683-
if (AccumEVT == MVT::i32)
5684-
Cost *= 2;
5685-
else if (AccumEVT != MVT::i64)
5686-
return Invalid;
5687-
break;
5688-
case 16:
5689-
if (AccumEVT == MVT::i64)
5690-
Cost *= 2;
5691-
else if (AccumEVT != MVT::i32)
5692-
return Invalid;
5693-
break;
5694-
}
5695-
} else if (InputEVT == MVT::i16) {
5696-
// FIXME: Allow i32 accumulator but increase cost, as we would extend
5697-
// it to i64.
5698-
if (VFMinValue != 8 || AccumEVT != MVT::i64)
5699-
return Invalid;
5700-
} else
5701-
return Invalid;
5712+
// i8 -> i32 is natively supported for udot/sdot/usdot, both for NEON and SVE.
5713+
if (ST->isSVEorStreamingSVEAvailable() ||
5714+
(AccumLT.second.isFixedLengthVector() && ST->isNeonAvailable() &&
5715+
ST->hasDotProd())) {
5716+
if (AccumLT.second.getScalarType() == MVT::i32 &&
5717+
InputLT.second.getScalarType() == MVT::i8)
5718+
return Cost;
5719+
}
57025720

5703-
return Cost;
5721+
// Add additional cost for the extends that would need to be inserted.
5722+
return Cost + 4;
57045723
}
57055724

57065725
InstructionCost

llvm/test/Transforms/LoopVectorize/AArch64/fully-unrolled-cost.ll

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -86,7 +86,7 @@ define i64 @test_two_ivs(ptr %a, ptr %b, i64 %start) #0 {
8686
; CHECK-NEXT: Cost of 0 for VF 16: induction instruction %i.iv = phi i64 [ 0, %entry ], [ %i.iv.next, %for.body ]
8787
; CHECK-NEXT: Cost of 0 for VF 16: induction instruction %j.iv = phi i64 [ %start, %entry ], [ %j.iv.next, %for.body ]
8888
; CHECK-NEXT: Cost of 0 for VF 16: EMIT vp<{{.+}}> = CANONICAL-INDUCTION ir<0>, vp<%index.next>
89-
; CHECK: Cost for VF 16: 48
89+
; CHECK: Cost for VF 16: 41
9090
; CHECK: LV: Selecting VF: 16
9191
entry:
9292
br label %for.body

0 commit comments

Comments
 (0)