Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 84 additions & 26 deletions llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1061,6 +1061,9 @@ InstructionCost RISCVTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
DstLT.second.getSizeInBits()))
return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);

// The split cost is handled by the base getCastInstrCost
assert((SrcLT.first == 1) && (DstLT.first == 1) && "Illegal type");

int ISD = TLI->InstructionOpcodeToISD(Opcode);
assert(ISD && "Invalid opcode");

Expand Down Expand Up @@ -1118,34 +1121,89 @@ InstructionCost RISCVTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
return Cost;
}
case ISD::FP_TO_SINT:
case ISD::FP_TO_UINT:
// For fp vector to mask, we use:
// vfncvt.rtz.x.f.w v9, v8
// vand.vi v8, v9, 1
// vmsne.vi v0, v8, 0
if (Dst->getScalarSizeInBits() == 1)
return 3;

if (std::abs(PowDiff) <= 1)
return 1;
case ISD::FP_TO_UINT: {
unsigned IsSigned = ISD == ISD::FP_TO_SINT;
unsigned FCVT = IsSigned ? RISCV::VFCVT_RTZ_X_F_V : RISCV::VFCVT_RTZ_XU_F_V;
unsigned FWCVT =
IsSigned ? RISCV::VFWCVT_RTZ_X_F_V : RISCV::VFWCVT_RTZ_XU_F_V;
unsigned FNCVT =
IsSigned ? RISCV::VFNCVT_RTZ_X_F_W : RISCV::VFNCVT_RTZ_XU_F_W;
unsigned SrcEltSize = Src->getScalarSizeInBits();
unsigned DstEltSize = Dst->getScalarSizeInBits();
InstructionCost Cost = 0;
if ((SrcEltSize == 16) &&
(!ST->hasVInstructionsF16() || ((DstEltSize / 2) > SrcEltSize))) {
// If the target only supports zvfhmin or it is fp16-to-i64 conversion
// pre-widening to f32 and then convert f32 to integer
VectorType *VecF32Ty =
VectorType::get(Type::getFloatTy(Dst->getContext()),
cast<VectorType>(Dst)->getElementCount());
std::pair<InstructionCost, MVT> VecF32LT =
getTypeLegalizationCost(VecF32Ty);
Cost +=
VecF32LT.first * getRISCVInstructionCost(RISCV::VFWCVT_F_F_V,
VecF32LT.second, CostKind);
Cost += getCastInstrCost(Opcode, Dst, VecF32Ty, CCH, CostKind, I);
return Cost;
}
if (DstEltSize == SrcEltSize)
Cost += getRISCVInstructionCost(FCVT, DstLT.second, CostKind);
else if (DstEltSize > SrcEltSize)
Cost += getRISCVInstructionCost(FWCVT, DstLT.second, CostKind);
else { // (SrcEltSize > DstEltSize)
// First do a narrowing conversion to an integer half the size, then
// truncate if needed.
MVT ElementVT = MVT::getIntegerVT(SrcEltSize / 2);
MVT VecVT = DstLT.second.changeVectorElementType(ElementVT);
Cost += getRISCVInstructionCost(FNCVT, VecVT, CostKind);
if ((SrcEltSize / 2) > DstEltSize) {
Type *VecTy = EVT(VecVT).getTypeForEVT(Dst->getContext());
Cost +=
getCastInstrCost(Instruction::Trunc, Dst, VecTy, CCH, CostKind, I);
}
}
return Cost;
}
case ISD::SINT_TO_FP:
case ISD::UINT_TO_FP: {
unsigned IsSigned = ISD == ISD::SINT_TO_FP;
unsigned FCVT = IsSigned ? RISCV::VFCVT_F_X_V : RISCV::VFCVT_F_XU_V;
unsigned FWCVT = IsSigned ? RISCV::VFWCVT_F_X_V : RISCV::VFWCVT_F_XU_V;
unsigned FNCVT = IsSigned ? RISCV::VFNCVT_F_X_W : RISCV::VFNCVT_F_XU_W;
unsigned SrcEltSize = Src->getScalarSizeInBits();
unsigned DstEltSize = Dst->getScalarSizeInBits();

// Counts of narrow/widen instructions.
return std::abs(PowDiff);
InstructionCost Cost = 0;
if ((DstEltSize == 16) &&
(!ST->hasVInstructionsF16() || ((SrcEltSize / 2) > DstEltSize))) {
// If the target only supports zvfhmin or it is i64-to-fp16 conversion
// it is converted to f32 and then converted to f16
VectorType *VecF32Ty =
VectorType::get(Type::getFloatTy(Dst->getContext()),
cast<VectorType>(Dst)->getElementCount());
std::pair<InstructionCost, MVT> VecF32LT =
getTypeLegalizationCost(VecF32Ty);
Cost += getCastInstrCost(Opcode, VecF32Ty, Src, CCH, CostKind, I);
Cost += VecF32LT.first * getRISCVInstructionCost(RISCV::VFNCVT_F_F_W,
DstLT.second, CostKind);
return Cost;
}

case ISD::SINT_TO_FP:
case ISD::UINT_TO_FP:
// For mask vector to fp, we should use the following instructions:
// vmv.v.i v8, 0
// vmerge.vim v8, v8, -1, v0
// vfcvt.f.x.v v8, v8
if (Src->getScalarSizeInBits() == 1)
return 3;

if (std::abs(PowDiff) <= 1)
return 1;
// Backend could lower (v[sz]ext i8 to double) to vfcvt(v[sz]ext.f8 i8),
// so it only need two conversion.
return 2;
if (DstEltSize == SrcEltSize)
Cost += getRISCVInstructionCost(FCVT, DstLT.second, CostKind);
else if (DstEltSize > SrcEltSize) {
if ((DstEltSize / 2) > SrcEltSize) {
VectorType *VecTy =
VectorType::get(IntegerType::get(Dst->getContext(), DstEltSize / 2),
cast<VectorType>(Dst)->getElementCount());
unsigned Op = IsSigned ? Instruction::SExt : Instruction::ZExt;
Cost += getCastInstrCost(Op, VecTy, Src, CCH, CostKind, I);
}
Cost += getRISCVInstructionCost(FWCVT, DstLT.second, CostKind);
} else
Cost += getRISCVInstructionCost(FNCVT, DstLT.second, CostKind);
return Cost;
}
}
return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
}
Expand Down
Loading