Skip to content

Commit 1c0ccf3

Browse files
artagnon authored and krishna2803 committed
[CostModel/RISCV] Fix costs of vector [l](lrint|lround) (llvm#146058)
Take the actual instruction cost into account, and don't fall through to code that doesn't apply to [l]lrint. Also strip the invalid costs for [b]f16, as a companion to llvm#146507, and unify the [l]lrint costs with the [l]lround costs, as a companion to llvm#147713.
1 parent c01de9a commit 1c0ccf3

File tree

2 files changed

+426
-245
lines changed

2 files changed

+426
-245
lines changed

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp

Lines changed: 37 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1191,9 +1191,6 @@ static const CostTblEntry VectorIntrinsicCostTable[]{
11911191
{Intrinsic::roundeven, MVT::f64, 9},
11921192
{Intrinsic::rint, MVT::f32, 7},
11931193
{Intrinsic::rint, MVT::f64, 7},
1194-
{Intrinsic::lrint, MVT::i32, 1},
1195-
{Intrinsic::lrint, MVT::i64, 1},
1196-
{Intrinsic::llrint, MVT::i64, 1},
11971194
{Intrinsic::nearbyint, MVT::f32, 9},
11981195
{Intrinsic::nearbyint, MVT::f64, 9},
11991196
{Intrinsic::bswap, MVT::i16, 3},
@@ -1251,11 +1248,43 @@ RISCVTTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
12511248
switch (ICA.getID()) {
12521249
case Intrinsic::lrint:
12531250
case Intrinsic::llrint:
1254-
// We can't currently lower half or bfloat vector lrint/llrint.
1255-
if (auto *VecTy = dyn_cast<VectorType>(ICA.getArgTypes()[0]);
1256-
VecTy && VecTy->getElementType()->is16bitFPTy())
1257-
return InstructionCost::getInvalid();
1258-
[[fallthrough]];
1251+
case Intrinsic::lround:
1252+
case Intrinsic::llround: {
1253+
auto LT = getTypeLegalizationCost(RetTy);
1254+
Type *SrcTy = ICA.getArgTypes().front();
1255+
auto SrcLT = getTypeLegalizationCost(SrcTy);
1256+
if (ST->hasVInstructions() && LT.second.isVector()) {
1257+
ArrayRef<unsigned> Ops;
1258+
unsigned SrcEltSz = DL.getTypeSizeInBits(SrcTy->getScalarType());
1259+
unsigned DstEltSz = DL.getTypeSizeInBits(RetTy->getScalarType());
1260+
if (LT.second.getVectorElementType() == MVT::bf16) {
1261+
if (!ST->hasVInstructionsBF16Minimal())
1262+
return InstructionCost::getInvalid();
1263+
if (DstEltSz == 32)
1264+
Ops = {RISCV::VFWCVTBF16_F_F_V, RISCV::VFCVT_X_F_V};
1265+
else
1266+
Ops = {RISCV::VFWCVTBF16_F_F_V, RISCV::VFWCVT_X_F_V};
1267+
} else if (LT.second.getVectorElementType() == MVT::f16 &&
1268+
!ST->hasVInstructionsF16()) {
1269+
if (!ST->hasVInstructionsF16Minimal())
1270+
return InstructionCost::getInvalid();
1271+
if (DstEltSz == 32)
1272+
Ops = {RISCV::VFWCVT_F_F_V, RISCV::VFCVT_X_F_V};
1273+
else
1274+
Ops = {RISCV::VFWCVT_F_F_V, RISCV::VFWCVT_X_F_V};
1275+
1276+
} else if (SrcEltSz > DstEltSz) {
1277+
Ops = {RISCV::VFNCVT_X_F_W};
1278+
} else if (SrcEltSz < DstEltSz) {
1279+
Ops = {RISCV::VFWCVT_X_F_V};
1280+
} else {
1281+
Ops = {RISCV::VFCVT_X_F_V};
1282+
}
1283+
return std::max(SrcLT.first, LT.first) *
1284+
getRISCVInstructionCost(Ops, LT.second, CostKind);
1285+
}
1286+
break;
1287+
}
12591288
case Intrinsic::ceil:
12601289
case Intrinsic::floor:
12611290
case Intrinsic::trunc:

0 commit comments

Comments
 (0)