Skip to content

Commit 0829f8c

Browse files
committed
Address comments
1 parent f5a11ac commit 0829f8c

File tree

5 files changed

+195
-185
lines changed

5 files changed

+195
-185
lines changed

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "llvm/CodeGen/BasicTTIImpl.h"
1919
#include "llvm/CodeGen/CostTable.h"
2020
#include "llvm/CodeGen/TargetLowering.h"
21+
#include "llvm/IR/DerivedTypes.h"
2122
#include "llvm/IR/IntrinsicInst.h"
2223
#include "llvm/IR/Intrinsics.h"
2324
#include "llvm/IR/IntrinsicsAArch64.h"
@@ -3616,17 +3617,26 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
36163617
// When SVE is available, we get:
36173618
// smulh + lsr + add/sub + asr + add/sub.
36183619
if (Ty->isScalableTy() && ST->hasSVE())
3619-
return 2 * MulCost /*smulh cost*/ + 2 * AddCost + 2 * AsrCost;
3620+
return MulCost /*smulh cost*/ + 2 * AddCost + 2 * AsrCost;
36203621
return 2 * MulCost + AddCost /*uzp2 cost*/ + AsrCost + UsraCost;
36213622
}
36223623
}
36233624
}
36243625
if (Op2Info.isConstant() && !Op2Info.isUniform() &&
36253626
LT.second.isFixedLengthVector()) {
3626-
auto VT = TLI->getValueType(DL, Ty);
3627-
return VT.getVectorNumElements() *
3628-
getArithmeticInstrCost(Opcode, Ty->getScalarType(), CostKind,
3629-
Op1Info.getNoProps(), Op2Info.getNoProps());
3627+
// FIXME: When the constant vector is non-uniform, this may result in
3628+
// loading the vector from the constant pool or, in some cases, in
3629+
// scalarization. For now, we are approximating this with the
3630+
// scalarization cost.
3631+
auto ExtractCost = 2 * getVectorInstrCost(Instruction::ExtractElement, Ty,
3632+
CostKind, -1, nullptr, nullptr);
3633+
auto InsertCost = getVectorInstrCost(Instruction::InsertElement, Ty,
3634+
CostKind, -1, nullptr, nullptr);
3635+
unsigned NElts = cast<FixedVectorType>(Ty)->getNumElements();
3636+
return ExtractCost + InsertCost +
3637+
NElts * getArithmeticInstrCost(Opcode, Ty->getScalarType(),
3638+
CostKind, Op1Info.getNoProps(),
3639+
Op2Info.getNoProps());
36303640
}
36313641
[[fallthrough]];
36323642
case ISD::UDIV:

0 commit comments

Comments
 (0)