Skip to content

Commit 2770732

Browse files
committed
[RISCV][TTI] Fix a costing mistake for truncate/fp_round with LMUL>m1
For a narrowing operation, the work performed scales with the source LMUL not the destination LMUL. A side effect of the code sharing with FP_EXTEND was that we used the wrong LMUL when costing the inserted narrowing operations. For casts which start with a high LMUL operation, this change makes the cost significantly more expensive.
1 parent d6081bf commit 2770732

File tree

2 files changed

+181
-172
lines changed

2 files changed

+181
-172
lines changed

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1077,24 +1077,33 @@ InstructionCost RISCVTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
10771077
SrcLT.second, CostKind);
10781078
}
10791079
[[fallthrough]];
1080-
case ISD::FP_EXTEND:
10811080
case ISD::FP_ROUND: {
1082-
// Counts of narrow/widen instructions.
1081+
// Counts of narrowing instructions.
10831082
unsigned SrcEltSize = Src->getScalarSizeInBits();
10841083
unsigned DstEltSize = Dst->getScalarSizeInBits();
10851084

1086-
unsigned Op = (ISD == ISD::TRUNCATE) ? RISCV::VNSRL_WI
1087-
: (ISD == ISD::FP_EXTEND) ? RISCV::VFWCVT_F_F_V
1088-
: RISCV::VFNCVT_F_F_W;
1085+
const unsigned Op =
1086+
(ISD == ISD::TRUNCATE) ? RISCV::VNSRL_WI : RISCV::VFNCVT_F_F_W;
10891087
InstructionCost Cost = 0;
1090-
for (; SrcEltSize != DstEltSize;) {
1088+
for (; SrcEltSize != DstEltSize; SrcEltSize = SrcEltSize >> 1) {
10911089
MVT ElementMVT = (ISD == ISD::TRUNCATE)
1092-
? MVT::getIntegerVT(DstEltSize)
1093-
: MVT::getFloatingPointVT(DstEltSize);
1090+
? MVT::getIntegerVT(SrcEltSize)
1091+
: MVT::getFloatingPointVT(SrcEltSize);
1092+
MVT SrcMVT = SrcLT.second.changeVectorElementType(ElementMVT);
1093+
Cost += getRISCVInstructionCost(Op, SrcMVT, CostKind);
1094+
}
1095+
return Cost;
1096+
}
1097+
case ISD::FP_EXTEND: {
1098+
// Counts of widening instructions.
1099+
unsigned SrcEltSize = Src->getScalarSizeInBits();
1100+
unsigned DstEltSize = Dst->getScalarSizeInBits();
1101+
1102+
InstructionCost Cost = 0;
1103+
for (; SrcEltSize != DstEltSize; DstEltSize = DstEltSize >> 1) {
1104+
MVT ElementMVT = MVT::getFloatingPointVT(DstEltSize);
10941105
MVT DstMVT = DstLT.second.changeVectorElementType(ElementMVT);
1095-
DstEltSize =
1096-
(DstEltSize > SrcEltSize) ? DstEltSize >> 1 : DstEltSize << 1;
1097-
Cost += getRISCVInstructionCost(Op, DstMVT, CostKind);
1106+
Cost += getRISCVInstructionCost(RISCV::VFWCVT_F_F_V, DstMVT, CostKind);
10981107
}
10991108
return Cost;
11001109
}

0 commit comments

Comments
 (0)