diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp index 68aec80f07e1d..9a00bf2dbc924 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -413,6 +413,20 @@ AArch64TTIImpl::getIntImmCost(const APInt &Imm, Type *Ty, return std::max(1, Cost); } +static bool isLegalArithImmed(uint64_t C) { + // Matches AArch64DAGToDAGISel::SelectArithImmed(). + bool IsLegal = (C >> 12 == 0) || ((C & 0xFFFULL) == 0 && C >> 24 == 0); + LLVM_DEBUG(dbgs() << "Is imm " << C + << " legal: " << (IsLegal ? "yes\n" : "no\n")); + return IsLegal; +} + +static bool isLegalCmpImmed(APInt C) { + // Works for negative immediates too, as it can be written as an ADDS + // instruction with a negated immediate. + return isLegalArithImmed(C.abs().getZExtValue()); +} + InstructionCost AArch64TTIImpl::getIntImmCostInst(unsigned Opcode, unsigned Idx, const APInt &Imm, Type *Ty, TTI::TargetCostKind CostKind, @@ -473,10 +487,19 @@ InstructionCost AArch64TTIImpl::getIntImmCostInst(unsigned Opcode, unsigned Idx, if (Idx == ImmIdx) { int NumConstants = (BitSize + 63) / 64; - InstructionCost Cost = AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind); - return (Cost <= NumConstants * TTI::TCC_Basic) - ? static_cast(TTI::TCC_Free) - : Cost; + InstructionCost Cost; + if ((Opcode == Instruction::Add || Opcode == Instruction::Sub || + Opcode == Instruction::ICmp) && + BitSize <= 64) { + // Add/Sub/ICmp immediates can be flipped. + // Also they have different requirements as to fitting in an immediate + // than others. + if (isLegalCmpImmed(Imm)) + return TTI::TCC_Free; + Cost = AArch64TTIImpl::getIntImmCost(Imm.abs(), Ty, CostKind); + } else + Cost = AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind); + return (Cost <= NumConstants * TTI::TCC_Basic) ? TTI::TCC_Free : Cost; } return AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind); } @@ -511,9 +534,7 @@ AArch64TTIImpl::getIntImmCostIntrin(Intrinsic::ID IID, unsigned Idx, if (Idx == 1) { int NumConstants = (BitSize + 63) / 64; InstructionCost Cost = AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind); - return (Cost <= NumConstants * TTI::TCC_Basic) - ? static_cast(TTI::TCC_Free) - : Cost; + return (Cost <= NumConstants * TTI::TCC_Basic) ? TTI::TCC_Free : Cost; } break; case Intrinsic::experimental_stackmap: