Skip to content

Commit 6ad25c5

Browse files
authored
[AArch64] Improve the cost model for extending mull (#125651)
We already have cost model code for detecting extending mull multiplies of the form `mul(ext, ext)`. Since it was added, the codegen for mull has been improved; this patch attempts to catch the cost model up. The main idea is to incorporate extends of larger sizes. A vector `v8i32 mul(zext(v8i8), zext(v8i8))` will be code-generated as `zext(v8i16 mul(zext(v8i8), zext(v8i8)))`, or umull+ushll+ushll2. So the total cost should be 3ish if each instruction costs 1. Where exactly we attribute the costs is debatable; this patch opts to set the cost of the extend to 0 (or the cost of the extend not included in the mull), and the mul gets the cost of the mull plus the extra extends. isWideningInstruction is split into two functions for the two types of operands it supports: isSingleExtWideningInstruction now handles addw instructions that extend the second operand, and isBinExtWideningInstruction is for instructions like addl that extend both operands.
1 parent 25a592c commit 6ad25c5

File tree

8 files changed

+460
-453
lines changed

8 files changed

+460
-453
lines changed

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 118 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -3007,9 +3007,9 @@ AArch64TTIImpl::getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const {
30073007
llvm_unreachable("Unsupported register kind");
30083008
}
30093009

3010-
bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
3011-
ArrayRef<const Value *> Args,
3012-
Type *SrcOverrideTy) const {
3010+
bool AArch64TTIImpl::isSingleExtWideningInstruction(
3011+
unsigned Opcode, Type *DstTy, ArrayRef<const Value *> Args,
3012+
Type *SrcOverrideTy) const {
30133013
// A helper that returns a vector type from the given type. The number of
30143014
// elements in type Ty determines the vector width.
30153015
auto toVectorTy = [&](Type *ArgTy) {
@@ -3027,48 +3027,29 @@ bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
30273027
(DstEltSize != 16 && DstEltSize != 32 && DstEltSize != 64))
30283028
return false;
30293029

3030-
// Determine if the operation has a widening variant. We consider both the
3031-
// "long" (e.g., usubl) and "wide" (e.g., usubw) versions of the
3032-
// instructions.
3033-
//
3034-
// TODO: Add additional widening operations (e.g., shl, etc.) once we
3035-
// verify that their extending operands are eliminated during code
3036-
// generation.
30373030
Type *SrcTy = SrcOverrideTy;
30383031
switch (Opcode) {
3039-
case Instruction::Add: // UADDL(2), SADDL(2), UADDW(2), SADDW(2).
3040-
case Instruction::Sub: // USUBL(2), SSUBL(2), USUBW(2), SSUBW(2).
3032+
case Instruction::Add: // UADDW(2), SADDW(2).
3033+
case Instruction::Sub: { // USUBW(2), SSUBW(2).
30413034
// The second operand needs to be an extend
30423035
if (isa<SExtInst>(Args[1]) || isa<ZExtInst>(Args[1])) {
30433036
if (!SrcTy)
30443037
SrcTy =
30453038
toVectorTy(cast<Instruction>(Args[1])->getOperand(0)->getType());
3046-
} else
3039+
break;
3040+
}
3041+
3042+
if (Opcode == Instruction::Sub)
30473043
return false;
3048-
break;
3049-
case Instruction::Mul: { // SMULL(2), UMULL(2)
3050-
// Both operands need to be extends of the same type.
3051-
if ((isa<SExtInst>(Args[0]) && isa<SExtInst>(Args[1])) ||
3052-
(isa<ZExtInst>(Args[0]) && isa<ZExtInst>(Args[1]))) {
3044+
3045+
// UADDW(2), SADDW(2) can be commutted.
3046+
if (isa<SExtInst>(Args[0]) || isa<ZExtInst>(Args[0])) {
30533047
if (!SrcTy)
30543048
SrcTy =
30553049
toVectorTy(cast<Instruction>(Args[0])->getOperand(0)->getType());
3056-
} else if (isa<ZExtInst>(Args[0]) || isa<ZExtInst>(Args[1])) {
3057-
// If one of the operands is a Zext and the other has enough zero bits to
3058-
// be treated as unsigned, we can still general a umull, meaning the zext
3059-
// is free.
3060-
KnownBits Known =
3061-
computeKnownBits(isa<ZExtInst>(Args[0]) ? Args[1] : Args[0], DL);
3062-
if (Args[0]->getType()->getScalarSizeInBits() -
3063-
Known.Zero.countLeadingOnes() >
3064-
DstTy->getScalarSizeInBits() / 2)
3065-
return false;
3066-
if (!SrcTy)
3067-
SrcTy = toVectorTy(Type::getIntNTy(DstTy->getContext(),
3068-
DstTy->getScalarSizeInBits() / 2));
3069-
} else
3070-
return false;
3071-
break;
3050+
break;
3051+
}
3052+
return false;
30723053
}
30733054
default:
30743055
return false;
@@ -3099,6 +3080,73 @@ bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
30993080
return NumDstEls == NumSrcEls && 2 * SrcElTySize == DstEltSize;
31003081
}
31013082

3083+
Type *AArch64TTIImpl::isBinExtWideningInstruction(unsigned Opcode, Type *DstTy,
3084+
ArrayRef<const Value *> Args,
3085+
Type *SrcOverrideTy) const {
3086+
if (Opcode != Instruction::Add && Opcode != Instruction::Sub &&
3087+
Opcode != Instruction::Mul)
3088+
return nullptr;
3089+
3090+
// Exit early if DstTy is not a vector type whose elements are one of [i16,
3091+
// i32, i64]. SVE doesn't generally have the same set of instructions to
3092+
// perform an extend with the add/sub/mul. There are SMULLB style
3093+
// instructions, but they operate on top/bottom, requiring some sort of lane
3094+
// interleaving to be used with zext/sext.
3095+
unsigned DstEltSize = DstTy->getScalarSizeInBits();
3096+
if (!useNeonVector(DstTy) || Args.size() != 2 ||
3097+
(DstEltSize != 16 && DstEltSize != 32 && DstEltSize != 64))
3098+
return nullptr;
3099+
3100+
auto getScalarSizeWithOverride = [&](const Value *V) {
3101+
if (SrcOverrideTy)
3102+
return SrcOverrideTy->getScalarSizeInBits();
3103+
return cast<Instruction>(V)
3104+
->getOperand(0)
3105+
->getType()
3106+
->getScalarSizeInBits();
3107+
};
3108+
3109+
unsigned MaxEltSize = 0;
3110+
if ((isa<SExtInst>(Args[0]) && isa<SExtInst>(Args[1])) ||
3111+
(isa<ZExtInst>(Args[0]) && isa<ZExtInst>(Args[1]))) {
3112+
unsigned EltSize0 = getScalarSizeWithOverride(Args[0]);
3113+
unsigned EltSize1 = getScalarSizeWithOverride(Args[1]);
3114+
MaxEltSize = std::max(EltSize0, EltSize1);
3115+
} else if (isa<SExtInst, ZExtInst>(Args[0]) &&
3116+
isa<SExtInst, ZExtInst>(Args[1])) {
3117+
unsigned EltSize0 = getScalarSizeWithOverride(Args[0]);
3118+
unsigned EltSize1 = getScalarSizeWithOverride(Args[1]);
3119+
// mul(sext, zext) will become smull(sext, zext) if the extends are large
3120+
// enough.
3121+
if (EltSize0 >= DstEltSize / 2 || EltSize1 >= DstEltSize / 2)
3122+
return nullptr;
3123+
MaxEltSize = DstEltSize / 2;
3124+
} else if (Opcode == Instruction::Mul &&
3125+
(isa<ZExtInst>(Args[0]) || isa<ZExtInst>(Args[1]))) {
3126+
// If one of the operands is a Zext and the other has enough zero bits
3127+
// to be treated as unsigned, we can still generate a umull, meaning the
3128+
// zext is free.
3129+
KnownBits Known =
3130+
computeKnownBits(isa<ZExtInst>(Args[0]) ? Args[1] : Args[0], DL);
3131+
if (Args[0]->getType()->getScalarSizeInBits() -
3132+
Known.Zero.countLeadingOnes() >
3133+
DstTy->getScalarSizeInBits() / 2)
3134+
return nullptr;
3135+
3136+
MaxEltSize =
3137+
getScalarSizeWithOverride(isa<ZExtInst>(Args[0]) ? Args[0] : Args[1]);
3138+
} else
3139+
return nullptr;
3140+
3141+
if (MaxEltSize * 2 > DstEltSize)
3142+
return nullptr;
3143+
3144+
Type *ExtTy = DstTy->getWithNewBitWidth(MaxEltSize * 2);
3145+
if (ExtTy->getPrimitiveSizeInBits() <= 64)
3146+
return nullptr;
3147+
return ExtTy;
3148+
}
3149+
31023150
// s/urhadd instructions implement the following pattern, making the
31033151
// extends free:
31043152
// %x = add ((zext i8 -> i16), 1)
@@ -3159,7 +3207,24 @@ InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
31593207
if (I && I->hasOneUser()) {
31603208
auto *SingleUser = cast<Instruction>(*I->user_begin());
31613209
SmallVector<const Value *, 4> Operands(SingleUser->operand_values());
3162-
if (isWideningInstruction(Dst, SingleUser->getOpcode(), Operands, Src)) {
3210+
if (Type *ExtTy = isBinExtWideningInstruction(
3211+
SingleUser->getOpcode(), Dst, Operands,
3212+
Src != I->getOperand(0)->getType() ? Src : nullptr)) {
3213+
// The cost from Src->Src*2 needs to be added if required, the cost from
3214+
// Src*2->ExtTy is free.
3215+
if (ExtTy->getScalarSizeInBits() > Src->getScalarSizeInBits() * 2) {
3216+
Type *DoubleSrcTy =
3217+
Src->getWithNewBitWidth(Src->getScalarSizeInBits() * 2);
3218+
return getCastInstrCost(Opcode, DoubleSrcTy, Src,
3219+
TTI::CastContextHint::None, CostKind);
3220+
}
3221+
3222+
return 0;
3223+
}
3224+
3225+
if (isSingleExtWideningInstruction(
3226+
SingleUser->getOpcode(), Dst, Operands,
3227+
Src != I->getOperand(0)->getType() ? Src : nullptr)) {
31633228
// For adds only count the second operand as free if both operands are
31643229
// extends but not the same operation. (i.e both operands are not free in
31653230
// add(sext, zext)).
@@ -3168,8 +3233,11 @@ InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
31683233
(isa<CastInst>(SingleUser->getOperand(1)) &&
31693234
cast<CastInst>(SingleUser->getOperand(1))->getOpcode() == Opcode))
31703235
return 0;
3171-
} else // Others are free so long as isWideningInstruction returned true.
3236+
} else {
3237+
// Others are free so long as isSingleExtWideningInstruction
3238+
// returned true.
31723239
return 0;
3240+
}
31733241
}
31743242

31753243
// The cast will be free for the s/urhadd instructions
@@ -4148,6 +4216,18 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
41484216
}))
41494217
return *PromotedCost;
41504218

4219+
// If the operation is a widening instruction (smull or umull) and both
4220+
// operands are extends the cost can be cheaper by considering that the
4221+
// operation will operate on the narrowest type size possible (double the
4222+
// largest input size) and a further extend.
4223+
if (Type *ExtTy = isBinExtWideningInstruction(Opcode, Ty, Args)) {
4224+
if (ExtTy != Ty)
4225+
return getArithmeticInstrCost(Opcode, ExtTy, CostKind) +
4226+
getCastInstrCost(Instruction::ZExt, Ty, ExtTy,
4227+
TTI::CastContextHint::None, CostKind);
4228+
return LT.first;
4229+
}
4230+
41514231
switch (ISD) {
41524232
default:
41534233
return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,
@@ -4381,10 +4461,8 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
43814461
// - two 2-cost i64 inserts, and
43824462
// - two 1-cost muls.
43834463
// So, for a v2i64 with LT.First = 1 the cost is 14, and for a v4i64 with
4384-
// LT.first = 2 the cost is 28. If both operands are extensions it will not
4385-
// need to scalarize so the cost can be cheaper (smull or umull).
4386-
// so the cost can be cheaper (smull or umull).
4387-
if (LT.second != MVT::v2i64 || isWideningInstruction(Ty, Opcode, Args))
4464+
// LT.first = 2 the cost is 28.
4465+
if (LT.second != MVT::v2i64)
43884466
return LT.first;
43894467
return cast<VectorType>(Ty)->getElementCount().getKnownMinValue() *
43904468
(getArithmeticInstrCost(Opcode, Ty->getScalarType(), CostKind) +

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,17 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
5959
VECTOR_LDST_FOUR_ELEMENTS
6060
};
6161

62-
bool isWideningInstruction(Type *DstTy, unsigned Opcode,
63-
ArrayRef<const Value *> Args,
64-
Type *SrcOverrideTy = nullptr) const;
62+
/// Given a add/sub/mul operation, detect a widening addl/subl/mull pattern
63+
/// where both operands can be treated like extends. Returns the minimal type
64+
/// needed to compute the operation.
65+
Type *isBinExtWideningInstruction(unsigned Opcode, Type *DstTy,
66+
ArrayRef<const Value *> Args,
67+
Type *SrcOverrideTy = nullptr) const;
68+
/// Given a add/sub operation with a single extend operand, detect a
69+
/// widening addw/subw pattern.
70+
bool isSingleExtWideningInstruction(unsigned Opcode, Type *DstTy,
71+
ArrayRef<const Value *> Args,
72+
Type *SrcOverrideTy = nullptr) const;
6573

6674
// A helper function called by 'getVectorInstrCost'.
6775
//

0 commit comments

Comments
 (0)