Commit 1a9569c

[RISCV][TTI] Avoid an infinite recursion issue in getCastInstrCost (#110164)
Calling into BasicTTI is not always safe. In particular, BasicTTI does not have a full legalization implementation (vector widening is missing), and falls back on scalarization. The problem is that scalarization for <N x i1> vectors is costed in terms of the cast API, so we can end up in an infinite recursive cycle.

The "right" fix for this would be to teach BasicTTI how to model the full legalization state machine, but several attempts at doing so have resulted in dead ends or undesirable cost changes for targets I don't understand.

This patch instead papers over the issue by avoiding the call to the base class when dealing with an i1 source or dest. This doesn't necessarily produce correct costs, but it should at least return something semi-sensible and not crash.

Fixes #108708
1 parent 296901f commit 1a9569c
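
The recursion described in the commit message can be illustrated with a toy model. The sketch below is not LLVM code: `CastQuery`, `baseCost`, `targetCost`, `HandleI1First`, and the depth cut-off are hypothetical names, and the unit costs are arbitrary. It assumes, as a simplification, that the shared fallback re-enters the target hook with the same query for i1 casts, and shows how answering i1 queries before delegating breaks the cycle.

```cpp
#include <cassert>

// Toy model only: none of these names exist in LLVM.
struct CastQuery {
  unsigned SrcBits; // scalar width of the source element type
  unsigned DstBits; // scalar width of the destination element type
};

// Cut-off so the demo reports a cycle instead of overflowing the stack.
constexpr int MaxDepth = 32;

int targetCost(CastQuery Q, bool HandleI1First, int Depth = 0);

// Stand-in for a shared fallback that prices i1 casts by calling back
// into the cast-cost hook (simplified here to the same query).
int baseCost(CastQuery Q, bool HandleI1First, int Depth) {
  if (Q.SrcBits == 1 || Q.DstBits == 1)
    return targetCost(Q, HandleI1First, Depth + 1); // re-enters the hook
  return 1; // arbitrary unit cost for every other cast
}

// Stand-in for the target hook. Returns -1 if the cut-off is hit.
int targetCost(CastQuery Q, bool HandleI1First, int Depth) {
  if (Depth > MaxDepth)
    return -1; // cycle detected
  if (HandleI1First && (Q.SrcBits == 1 || Q.DstBits == 1))
    return 2; // assumed cost of a two-instruction sequence, no base call
  return baseCost(Q, HandleI1First, Depth); // delegate everything else
}
```

With `HandleI1First` set to `false`, an i1 query bounces between the two functions until the cut-off trips; with `true`, it is answered directly and only non-i1 casts reach the fallback.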

3 files changed, +158 -115 lines changed

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp

Lines changed: 38 additions & 21 deletions
@@ -1163,9 +1163,47 @@ InstructionCost RISCVTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
       Dst->getScalarSizeInBits() > ST->getELen())
     return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
 
+  int ISD = TLI->InstructionOpcodeToISD(Opcode);
+  assert(ISD && "Invalid opcode");
   std::pair<InstructionCost, MVT> SrcLT = getTypeLegalizationCost(Src);
   std::pair<InstructionCost, MVT> DstLT = getTypeLegalizationCost(Dst);
 
+  // Handle i1 source and dest cases *before* calling logic in BasicTTI.
+  // The shared implementation doesn't model vector widening during legalization
+  // and instead assumes scalarization. In order to scalarize an <N x i1>
+  // vector, we need to extend/trunc to/from i8. If we don't special case
+  // this, we can get an infinite recursion cycle.
+  switch (ISD) {
+  default:
+    break;
+  case ISD::SIGN_EXTEND:
+  case ISD::ZERO_EXTEND:
+    if (Src->getScalarSizeInBits() == 1) {
+      // We do not use vsext/vzext to extend from mask vector.
+      // Instead we use the following instructions to extend from mask vector:
+      // vmv.v.i v8, 0
+      // vmerge.vim v8, v8, -1, v0
+      return DstLT.first *
+                 getRISCVInstructionCost({RISCV::VMV_V_I, RISCV::VMERGE_VIM},
+                                         DstLT.second, CostKind) +
+             DstLT.first - 1;
+    }
+    break;
+  case ISD::TRUNCATE:
+    if (Dst->getScalarSizeInBits() == 1) {
+      // We do not use several vncvt to truncate to mask vector. So we could
+      // not use PowDiff to calculate it.
+      // Instead we use the following instructions to truncate to mask vector:
+      // vand.vi v8, v8, 1
+      // vmsne.vi v0, v8, 0
+      return SrcLT.first *
+                 getRISCVInstructionCost({RISCV::VAND_VI, RISCV::VMSNE_VI},
+                                         SrcLT.second, CostKind) +
+             SrcLT.first - 1;
+    }
+    break;
+  };
+
   // Our actual lowering for the case where a wider legal type is available
   // uses promotion to the wider type. This is reflected in the result of
   // getTypeLegalizationCost, but BasicTTI assumes the widened cases are
@@ -1181,22 +1219,11 @@ InstructionCost RISCVTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
   // The split cost is handled by the base getCastInstrCost
   assert((SrcLT.first == 1) && (DstLT.first == 1) && "Illegal type");
 
-  int ISD = TLI->InstructionOpcodeToISD(Opcode);
-  assert(ISD && "Invalid opcode");
-
   int PowDiff = (int)Log2_32(DstLT.second.getScalarSizeInBits()) -
                 (int)Log2_32(SrcLT.second.getScalarSizeInBits());
   switch (ISD) {
   case ISD::SIGN_EXTEND:
   case ISD::ZERO_EXTEND: {
-    if (Src->getScalarSizeInBits() == 1) {
-      // We do not use vsext/vzext to extend from mask vector.
-      // Instead we use the following instructions to extend from mask vector:
-      // vmv.v.i v8, 0
-      // vmerge.vim v8, v8, -1, v0
-      return getRISCVInstructionCost({RISCV::VMV_V_I, RISCV::VMERGE_VIM},
-                                     DstLT.second, CostKind);
-    }
     if ((PowDiff < 1) || (PowDiff > 3))
       return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
     unsigned SExtOp[] = {RISCV::VSEXT_VF2, RISCV::VSEXT_VF4, RISCV::VSEXT_VF8};
@@ -1206,16 +1233,6 @@ InstructionCost RISCVTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
     return getRISCVInstructionCost(Op, DstLT.second, CostKind);
   }
   case ISD::TRUNCATE:
-    if (Dst->getScalarSizeInBits() == 1) {
-      // We do not use several vncvt to truncate to mask vector. So we could
-      // not use PowDiff to calculate it.
-      // Instead we use the following instructions to truncate to mask vector:
-      // vand.vi v8, v8, 1
-      // vmsne.vi v0, v8, 0
-      return getRISCVInstructionCost({RISCV::VAND_VI, RISCV::VMSNE_VI},
-                                     SrcLT.second, CostKind);
-    }
-    [[fallthrough]];
   case ISD::FP_EXTEND:
   case ISD::FP_ROUND: {
     // Counts of narrow/widen instructions.
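
The relocated i1 cases no longer return the bare instruction-sequence cost; they scale it by the first member of the legalization result and add that value minus one. The arithmetic can be sketched as below. `i1CastCost`, `NumParts`, and `PerPartSeqCost` are hypothetical names standing in for `LT.first` and the `getRISCVInstructionCost` result, and the sample numbers are made up. The diff does not say what the extra `first - 1` term models, so the sketch only reproduces the expression's shape.

```cpp
#include <cassert>

// Hypothetical helper mirroring the shape of the new return expressions:
//   LT.first * getRISCVInstructionCost({...}, LT.second, CostKind)
//     + LT.first - 1
int i1CastCost(int NumParts, int PerPartSeqCost) {
  assert(NumParts >= 1 && "expected a legalization factor of at least one");
  return NumParts * PerPartSeqCost + NumParts - 1;
}
```

With a factor of 1 the expression reduces to the sequence cost alone, which is what the removed code returned: `i1CastCost(1, 2)` is 2. With a factor of 4 and the same assumed per-part cost, `i1CastCost(4, 2)` is 11.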
