Skip to content

Commit 13f83b0

Browse files
committed
[AArch64] Improve urem by constant costs
A urem by a constant, much like a udiv by a constant, can be expanded into a series of mul/add/shift instructions. The exact sequence of instructions depends on the constants and the types. If the constant is a power of 2 then a shift (for udiv) or an and (for urem) will be used, so the cost will be 1. This canonicalization happens relatively early, so this likely has very little effect in practice (it does help the cost of funnel shifts). For a non-power-of-2 constant the code for div will expand to a series of UMULH + Add + Shift + Add, depending on the constant. urem is generally udiv + mul + sub, so it involves a few extra instructions. The UMULH is not always available: i32 will use umull+shift, and vector types will use umull+shift or umull+umull2+uzp depending on the vector size. v2i64 will be scalarized because there is no mul available. SVE does have a UMULH instruction. The end result is that the costs should be closer to reality, with scalable types a little lower cost than the fixed-width versions. (In the future we might be able to use umulh for fixed-width when the SVE instruction is available, but for the moment this should favour scalable vectorization a little.) I've tried to make this patch only apply to constant UREM/UDIV instructions. SDIV and SREM are left until a later patch to prevent this becoming too complex. The funnel shift costs are changing as the cost model believes it will need a urem to clamp the shift amount, which should be a power-of-2 value for most common types.
1 parent f22441c commit 13f83b0

File tree

10 files changed

+607
-566
lines changed

10 files changed

+607
-566
lines changed

llvm/lib/Analysis/TargetTransformInfo.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -893,7 +893,8 @@ TargetTransformInfo::getOperandInfo(const Value *V) {
893893

894894
// Check for a splat of a constant or for a non uniform vector of constants
895895
// and check if the constant(s) are all powers of two.
896-
if (isa<ConstantVector>(V) || isa<ConstantDataVector>(V)) {
896+
if (isa<ConstantVector>(V) || isa<ConstantDataVector>(V) ||
897+
isa<ConstantExpr>(V)) {
897898
OpInfo = OK_NonUniformConstantValue;
898899
if (Splat) {
899900
OpInfo = OK_UniformConstantValue;

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 52 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3510,21 +3510,61 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
35103510
return Cost;
35113511
}
35123512
[[fallthrough]];
3513-
case ISD::UDIV: {
3513+
case ISD::UDIV:
3514+
case ISD::UREM: {
35143515
auto VT = TLI->getValueType(DL, Ty);
3515-
if (Op2Info.isConstant() && Op2Info.isUniform()) {
3516-
if (TLI->isOperationLegalOrCustom(ISD::MULHU, VT)) {
3516+
if (Op2Info.isConstant()) {
3517+
// If the operand is a power of 2 we can use the shift or and cost.
3518+
if (ISD == ISD::UDIV && Op2Info.isPowerOf2())
3519+
return getArithmeticInstrCost(Instruction::LShr, Ty, CostKind,
3520+
Op1Info.getNoProps(),
3521+
Op2Info.getNoProps());
3522+
if (ISD == ISD::UREM && Op2Info.isPowerOf2())
3523+
return getArithmeticInstrCost(Instruction::And, Ty, CostKind,
3524+
Op1Info.getNoProps(),
3525+
Op2Info.getNoProps());
3526+
3527+
if (ISD == ISD::UDIV || ISD == ISD::UREM) {
3528+
// Divides by a constant are expanded to MULHU + SUB + SRL + ADD + SRL.
3529+
// The MULHU will be expanded to UMULL for the types not listed below,
3530+
// and will become a pair of UMULL+UMULL2 for 128bit vectors.
3531+
bool HasMULH = VT == MVT::i64 || LT.second == MVT::nxv2i64 ||
3532+
LT.second == MVT::nxv4i32 || LT.second == MVT::nxv8i16 ||
3533+
LT.second == MVT::nxv16i8;
3534+
bool Is128bit = LT.second.is128BitVector();
3535+
3536+
InstructionCost MulCost =
3537+
getArithmeticInstrCost(Instruction::Mul, Ty, CostKind,
3538+
Op1Info.getNoProps(), Op2Info.getNoProps());
3539+
InstructionCost AddCost =
3540+
getArithmeticInstrCost(Instruction::Add, Ty, CostKind,
3541+
Op1Info.getNoProps(), Op2Info.getNoProps());
3542+
InstructionCost ShrCost =
3543+
getArithmeticInstrCost(Instruction::AShr, Ty, CostKind,
3544+
Op1Info.getNoProps(), Op2Info.getNoProps());
3545+
InstructionCost DivCost = MulCost * (Is128bit ? 2 : 1) + // UMULL/UMULH
3546+
(HasMULH ? 0 : ShrCost) + // UMULL shift
3547+
AddCost * 2 + ShrCost;
3548+
return DivCost + (ISD == ISD::UREM ? MulCost + AddCost : 0);
3549+
}
3550+
3551+
// TODO: Fix SDIV and SREM costs, similar to the above.
3552+
if (TLI->isOperationLegalOrCustom(ISD::MULHU, VT) &&
3553+
Op2Info.isUniform()) {
35173554
// Vector signed division by constant is expanded to the
3518-
// sequence MULHS + ADD/SUB + SRA + SRL + ADD, and unsigned division
3519-
// to MULHS + SUB + SRL + ADD + SRL.
3520-
InstructionCost MulCost = getArithmeticInstrCost(
3521-
Instruction::Mul, Ty, CostKind, Op1Info.getNoProps(), Op2Info.getNoProps());
3522-
InstructionCost AddCost = getArithmeticInstrCost(
3523-
Instruction::Add, Ty, CostKind, Op1Info.getNoProps(), Op2Info.getNoProps());
3524-
InstructionCost ShrCost = getArithmeticInstrCost(
3525-
Instruction::AShr, Ty, CostKind, Op1Info.getNoProps(), Op2Info.getNoProps());
3555+
// sequence MULHS + ADD/SUB + SRA + SRL + ADD.
3556+
InstructionCost MulCost =
3557+
getArithmeticInstrCost(Instruction::Mul, Ty, CostKind,
3558+
Op1Info.getNoProps(), Op2Info.getNoProps());
3559+
InstructionCost AddCost =
3560+
getArithmeticInstrCost(Instruction::Add, Ty, CostKind,
3561+
Op1Info.getNoProps(), Op2Info.getNoProps());
3562+
InstructionCost ShrCost =
3563+
getArithmeticInstrCost(Instruction::AShr, Ty, CostKind,
3564+
Op1Info.getNoProps(), Op2Info.getNoProps());
35263565
return MulCost * 2 + AddCost * 2 + ShrCost * 2 + 1;
35273566
}
3567+
35283568
}
35293569

35303570
// div i128's are lowered as libcalls. Pass nullptr as (u)divti3 calls are
@@ -3535,7 +3575,7 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
35353575

35363576
InstructionCost Cost = BaseT::getArithmeticInstrCost(
35373577
Opcode, Ty, CostKind, Op1Info, Op2Info);
3538-
if (Ty->isVectorTy()) {
3578+
if (Ty->isVectorTy() && (ISD == ISD::SDIV || ISD == ISD::UDIV)) {
35393579
if (TLI->isOperationLegalOrCustom(ISD, LT.second) && ST->hasSVE()) {
35403580
// When SDIV/UDIV operations are lowered using SVE, we can have lower
35413581
// costs.

0 commit comments

Comments
 (0)