Skip to content

Commit 549aa8b

Browse files
committed
[AArch64] Improve the cost model for extending mull
We already have cost model code for detecting extending mull multiplies for the form `mul(ext, ext)`. Since it was added the codegen for mull has been improved, this attempts to catch the cost model up. The main idea is to incorporate extends of larger sizes. A vector `v8i32 mul(zext(v8i8), zext(v8i8))` will be code-generated as `zext (v8i16 mul(zext(v8i8), zext(v8i8))`, or ushll+ushll2+umull. So the total cost should be 3ish if each instruction costs 1. Where exactly we attribute the costs is dependable, this patch opts to sets the cost of the extend to 0 (or the cost of the extend not included in the mull) and the mul gets the cost of the mull+extra extends. isWideningInstruction is split into two functions for the two types of operands it supports. isSingleExtWideningInstruction now handles addw instructions that extend the second operand, isBinExtWideningInstruction is for instructions like addl that extend both operands. The changes in the partial reduction tests show that they need a better cost model, that treats the mul + extends as free for the dot. It would be best to fix that first.
1 parent 077e0c1 commit 549aa8b

File tree

7 files changed

+1267
-965
lines changed

7 files changed

+1267
-965
lines changed

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 118 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -2585,9 +2585,9 @@ AArch64TTIImpl::getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const {
25852585
llvm_unreachable("Unsupported register kind");
25862586
}
25872587

2588-
bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
2589-
ArrayRef<const Value *> Args,
2590-
Type *SrcOverrideTy) {
2588+
bool AArch64TTIImpl::isSingleExtWideningInstruction(
2589+
unsigned Opcode, Type *DstTy, ArrayRef<const Value *> Args,
2590+
Type *SrcOverrideTy) {
25912591
// A helper that returns a vector type from the given type. The number of
25922592
// elements in type Ty determines the vector width.
25932593
auto toVectorTy = [&](Type *ArgTy) {
@@ -2605,48 +2605,29 @@ bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
26052605
(DstEltSize != 16 && DstEltSize != 32 && DstEltSize != 64))
26062606
return false;
26072607

2608-
// Determine if the operation has a widening variant. We consider both the
2609-
// "long" (e.g., usubl) and "wide" (e.g., usubw) versions of the
2610-
// instructions.
2611-
//
2612-
// TODO: Add additional widening operations (e.g., shl, etc.) once we
2613-
// verify that their extending operands are eliminated during code
2614-
// generation.
26152608
Type *SrcTy = SrcOverrideTy;
26162609
switch (Opcode) {
2617-
case Instruction::Add: // UADDL(2), SADDL(2), UADDW(2), SADDW(2).
2618-
case Instruction::Sub: // USUBL(2), SSUBL(2), USUBW(2), SSUBW(2).
2610+
case Instruction::Add: // UADDW(2), SADDW(2).
2611+
case Instruction::Sub: { // USUBW(2), SSUBW(2).
26192612
// The second operand needs to be an extend
26202613
if (isa<SExtInst>(Args[1]) || isa<ZExtInst>(Args[1])) {
26212614
if (!SrcTy)
26222615
SrcTy =
26232616
toVectorTy(cast<Instruction>(Args[1])->getOperand(0)->getType());
2624-
} else
2617+
break;
2618+
}
2619+
2620+
if (Opcode == Instruction::Sub)
26252621
return false;
2626-
break;
2627-
case Instruction::Mul: { // SMULL(2), UMULL(2)
2628-
// Both operands need to be extends of the same type.
2629-
if ((isa<SExtInst>(Args[0]) && isa<SExtInst>(Args[1])) ||
2630-
(isa<ZExtInst>(Args[0]) && isa<ZExtInst>(Args[1]))) {
2622+
2623+
// UADDW(2), SADDW(2) can be commutted.
2624+
if (isa<SExtInst>(Args[0]) || isa<ZExtInst>(Args[0])) {
26312625
if (!SrcTy)
26322626
SrcTy =
26332627
toVectorTy(cast<Instruction>(Args[0])->getOperand(0)->getType());
2634-
} else if (isa<ZExtInst>(Args[0]) || isa<ZExtInst>(Args[1])) {
2635-
// If one of the operands is a Zext and the other has enough zero bits to
2636-
// be treated as unsigned, we can still general a umull, meaning the zext
2637-
// is free.
2638-
KnownBits Known =
2639-
computeKnownBits(isa<ZExtInst>(Args[0]) ? Args[1] : Args[0], DL);
2640-
if (Args[0]->getType()->getScalarSizeInBits() -
2641-
Known.Zero.countLeadingOnes() >
2642-
DstTy->getScalarSizeInBits() / 2)
2643-
return false;
2644-
if (!SrcTy)
2645-
SrcTy = toVectorTy(Type::getIntNTy(DstTy->getContext(),
2646-
DstTy->getScalarSizeInBits() / 2));
2647-
} else
2648-
return false;
2649-
break;
2628+
break;
2629+
}
2630+
return false;
26502631
}
26512632
default:
26522633
return false;
@@ -2677,6 +2658,73 @@ bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
26772658
return NumDstEls == NumSrcEls && 2 * SrcElTySize == DstEltSize;
26782659
}
26792660

2661+
Type *AArch64TTIImpl::isBinExtWideningInstruction(unsigned Opcode, Type *DstTy,
2662+
ArrayRef<const Value *> Args,
2663+
Type *SrcOverrideTy) {
2664+
if (Opcode != Instruction::Add && Opcode != Instruction::Sub &&
2665+
Opcode != Instruction::Mul)
2666+
return nullptr;
2667+
2668+
// Exit early if DstTy is not a vector type whose elements are one of [i16,
2669+
// i32, i64]. SVE doesn't generally have the same set of instructions to
2670+
// perform an extend with the add/sub/mul. There are SMULLB style
2671+
// instructions, but they operate on top/bottom, requiring some sort of lane
2672+
// interleaving to be used with zext/sext.
2673+
unsigned DstEltSize = DstTy->getScalarSizeInBits();
2674+
if (!useNeonVector(DstTy) || Args.size() != 2 ||
2675+
(DstEltSize != 16 && DstEltSize != 32 && DstEltSize != 64))
2676+
return nullptr;
2677+
2678+
auto getScalarSizeWithOverride = [&](const Value *V) {
2679+
if (SrcOverrideTy)
2680+
return SrcOverrideTy->getScalarSizeInBits();
2681+
return cast<Instruction>(V)
2682+
->getOperand(0)
2683+
->getType()
2684+
->getScalarSizeInBits();
2685+
};
2686+
2687+
unsigned MaxEltSize = 0;
2688+
if ((isa<SExtInst>(Args[0]) && isa<SExtInst>(Args[1])) ||
2689+
(isa<ZExtInst>(Args[0]) && isa<ZExtInst>(Args[1]))) {
2690+
unsigned EltSize0 = getScalarSizeWithOverride(Args[0]);
2691+
unsigned EltSize1 = getScalarSizeWithOverride(Args[1]);
2692+
MaxEltSize = std::max(EltSize0, EltSize1);
2693+
} else if (isa<SExtInst, ZExtInst>(Args[0]) &&
2694+
isa<SExtInst, ZExtInst>(Args[1])) {
2695+
unsigned EltSize0 = getScalarSizeWithOverride(Args[0]);
2696+
unsigned EltSize1 = getScalarSizeWithOverride(Args[1]);
2697+
// mul(sext, zext) will become smull(sext, zext) if the extends are large
2698+
// enough.
2699+
if (EltSize0 >= DstEltSize / 2 || EltSize1 >= DstEltSize / 2)
2700+
return nullptr;
2701+
MaxEltSize = DstEltSize / 2;
2702+
} else if (Opcode == Instruction::Mul &&
2703+
(isa<ZExtInst>(Args[0]) || isa<ZExtInst>(Args[1]))) {
2704+
// If one of the operands is a Zext and the other has enough zero bits
2705+
// to be treated as unsigned, we can still generate a umull, meaning the
2706+
// zext is free.
2707+
KnownBits Known =
2708+
computeKnownBits(isa<ZExtInst>(Args[0]) ? Args[1] : Args[0], DL);
2709+
if (Args[0]->getType()->getScalarSizeInBits() -
2710+
Known.Zero.countLeadingOnes() >
2711+
DstTy->getScalarSizeInBits() / 2)
2712+
return nullptr;
2713+
2714+
MaxEltSize =
2715+
getScalarSizeWithOverride(isa<ZExtInst>(Args[0]) ? Args[0] : Args[1]);
2716+
} else
2717+
return nullptr;
2718+
2719+
if (MaxEltSize * 2 > DstEltSize)
2720+
return nullptr;
2721+
2722+
Type *ExtTy = DstTy->getWithNewBitWidth(MaxEltSize * 2);
2723+
if (ExtTy->getPrimitiveSizeInBits() <= 64)
2724+
return nullptr;
2725+
return ExtTy;
2726+
}
2727+
26802728
// s/urhadd instructions implement the following pattern, making the
26812729
// extends free:
26822730
// %x = add ((zext i8 -> i16), 1)
@@ -2737,7 +2785,24 @@ InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
27372785
if (I && I->hasOneUser()) {
27382786
auto *SingleUser = cast<Instruction>(*I->user_begin());
27392787
SmallVector<const Value *, 4> Operands(SingleUser->operand_values());
2740-
if (isWideningInstruction(Dst, SingleUser->getOpcode(), Operands, Src)) {
2788+
if (Type *ExtTy = isBinExtWideningInstruction(
2789+
SingleUser->getOpcode(), Dst, Operands,
2790+
Src != I->getOperand(0)->getType() ? Src : nullptr)) {
2791+
// The cost from Src->Src*2 needs to be added if required, the cost from
2792+
// Src*2->ExtTy is free.
2793+
if (ExtTy->getScalarSizeInBits() > Src->getScalarSizeInBits() * 2) {
2794+
Type *DoubleSrcTy =
2795+
Src->getWithNewBitWidth(Src->getScalarSizeInBits() * 2);
2796+
return getCastInstrCost(Opcode, DoubleSrcTy, Src,
2797+
TTI::CastContextHint::None, CostKind);
2798+
}
2799+
2800+
return 0;
2801+
}
2802+
2803+
if (isSingleExtWideningInstruction(
2804+
SingleUser->getOpcode(), Dst, Operands,
2805+
Src != I->getOperand(0)->getType() ? Src : nullptr)) {
27412806
// For adds only count the second operand as free if both operands are
27422807
// extends but not the same operation. (i.e both operands are not free in
27432808
// add(sext, zext)).
@@ -2746,8 +2811,11 @@ InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
27462811
(isa<CastInst>(SingleUser->getOperand(1)) &&
27472812
cast<CastInst>(SingleUser->getOperand(1))->getOpcode() == Opcode))
27482813
return 0;
2749-
} else // Others are free so long as isWideningInstruction returned true.
2814+
} else {
2815+
// Others are free so long as isSingleExtWideningInstruction
2816+
// returned true.
27502817
return 0;
2818+
}
27512819
}
27522820

27532821
// The cast will be free for the s/urhadd instructions
@@ -3496,6 +3564,18 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
34963564
std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Ty);
34973565
int ISD = TLI->InstructionOpcodeToISD(Opcode);
34983566

3567+
// If the operation is a widening instruction (smull or umull) and both
3568+
// operands are extends the cost can be cheaper by considering that the
3569+
// operation will operate on the narrowest type size possible (double the
3570+
// largest input size) and a further extend.
3571+
if (Type *ExtTy = isBinExtWideningInstruction(Opcode, Ty, Args)) {
3572+
if (ExtTy != Ty)
3573+
return getArithmeticInstrCost(Opcode, ExtTy, CostKind) +
3574+
getCastInstrCost(Instruction::ZExt, Ty, ExtTy,
3575+
TTI::CastContextHint::None, CostKind);
3576+
return LT.first;
3577+
}
3578+
34993579
switch (ISD) {
35003580
default:
35013581
return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,
@@ -3613,10 +3693,8 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
36133693
// - two 2-cost i64 inserts, and
36143694
// - two 1-cost muls.
36153695
// So, for a v2i64 with LT.First = 1 the cost is 14, and for a v4i64 with
3616-
// LT.first = 2 the cost is 28. If both operands are extensions it will not
3617-
// need to scalarize so the cost can be cheaper (smull or umull).
3618-
// so the cost can be cheaper (smull or umull).
3619-
if (LT.second != MVT::v2i64 || isWideningInstruction(Ty, Opcode, Args))
3696+
// LT.first = 2 the cost is 28.
3697+
if (LT.second != MVT::v2i64)
36203698
return LT.first;
36213699
return cast<VectorType>(Ty)->getElementCount().getKnownMinValue() *
36223700
(getArithmeticInstrCost(Opcode, Ty->getScalarType(), CostKind) +

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,17 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
5757
VECTOR_LDST_FOUR_ELEMENTS
5858
};
5959

60-
bool isWideningInstruction(Type *DstTy, unsigned Opcode,
61-
ArrayRef<const Value *> Args,
62-
Type *SrcOverrideTy = nullptr);
60+
/// Given a add/sub/mul operation, detect a widening addl/subl/mull pattern
61+
/// where both operands can be treated like extends. Returns the minimal type
62+
/// needed to compute the operation.
63+
Type *isBinExtWideningInstruction(unsigned Opcode, Type *DstTy,
64+
ArrayRef<const Value *> Args,
65+
Type *SrcOverrideTy = nullptr);
66+
/// Given a add/sub operation with a single extend operand, detect a
67+
/// widening addw/subw pattern.
68+
bool isSingleExtWideningInstruction(unsigned Opcode, Type *DstTy,
69+
ArrayRef<const Value *> Args,
70+
Type *SrcOverrideTy = nullptr);
6371

6472
// A helper function called by 'getVectorInstrCost'.
6573
//

0 commit comments

Comments
 (0)