Skip to content

Commit fc2f27c

Browse files
committed
[RISCV] Add support for RVV int<->fp & fp<->fp conversions
This patch adds support for the full range of vector int-to-float, float-to-int, and float-to-float conversions on legal types. Many conversions are supported natively in RVV so are lowered with patterns. These include conversions between (element) types of the same size, and those that are half/double the size of the input. When conversions take place between types that are less than half or more than double the size we must lower them using sequences of instructions which go via intermediate types. Reviewed By: craig.topper Differential Revision: https://reviews.llvm.org/D95447
1 parent 2393b03 commit fc2f27c

File tree

7 files changed

+4775
-0
lines changed

7 files changed

+4775
-0
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,15 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
402402
setOperationAction(ISD::SIGN_EXTEND, VT, Custom);
403403
setOperationAction(ISD::ZERO_EXTEND, VT, Custom);
404404

405+
// RVV has native int->float & float->int conversions where the
406+
// element type sizes are within one power-of-two of each other. Any
407+
// wider distances between type sizes have to be lowered as sequences
408+
// which progressively narrow the gap in stages.
409+
setOperationAction(ISD::SINT_TO_FP, VT, Custom);
410+
setOperationAction(ISD::UINT_TO_FP, VT, Custom);
411+
setOperationAction(ISD::FP_TO_SINT, VT, Custom);
412+
setOperationAction(ISD::FP_TO_UINT, VT, Custom);
413+
405414
// Integer VTs are lowered as a series of "RISCVISD::TRUNCATE_VECTOR"
406415
// nodes which truncate by one power of two at a time.
407416
setOperationAction(ISD::TRUNCATE, VT, Custom);
@@ -427,9 +436,16 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
427436
// Sets common operation actions on RVV floating-point vector types.
428437
const auto SetCommonVFPActions = [&](MVT VT) {
429438
setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
439+
// RVV has native FP_ROUND & FP_EXTEND conversions where the element type
440+
// sizes are within one power-of-two of each other. Therefore conversions
441+
// between vXf16 and vXf64 must be lowered as sequences which convert via
442+
// vXf32.
443+
setOperationAction(ISD::FP_ROUND, VT, Custom);
444+
setOperationAction(ISD::FP_EXTEND, VT, Custom);
430445
// Custom-lower insert/extract operations to simplify patterns.
431446
setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Custom);
432447
setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom);
448+
// Expand various condition codes (explained above).
433449
for (auto CC : VFPCCToExpand)
434450
setCondCodeAction(CC, VT, Expand);
435451
};
@@ -771,6 +787,99 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
771787
DAG.getConstant(3, DL, VT));
772788
return DAG.getNode(ISD::MUL, DL, VT, VScale, Op.getOperand(0));
773789
}
790+
case ISD::FP_EXTEND: {
791+
// RVV can only do fp_extend to types double the size of the source. We
792+
// custom-lower f16->f64 extensions to two hops of ISD::FP_EXTEND, going
793+
// via f32.
794+
MVT VT = Op.getSimpleValueType();
795+
MVT SrcVT = Op.getOperand(0).getSimpleValueType();
796+
// We only need to close the gap between vXf16->vXf64.
797+
if (!VT.isVector() || VT.getVectorElementType() != MVT::f64 ||
798+
SrcVT.getVectorElementType() != MVT::f16)
799+
return Op;
800+
SDLoc DL(Op);
801+
MVT InterVT = MVT::getVectorVT(MVT::f32, VT.getVectorElementCount());
802+
SDValue IntermediateRound =
803+
DAG.getFPExtendOrRound(Op.getOperand(0), DL, InterVT);
804+
return DAG.getFPExtendOrRound(IntermediateRound, DL, VT);
805+
}
806+
case ISD::FP_ROUND: {
807+
// RVV can only do fp_round to types half the size of the source. We
808+
// custom-lower f64->f16 rounds via RVV's round-to-odd float
809+
// conversion instruction.
810+
MVT VT = Op.getSimpleValueType();
811+
MVT SrcVT = Op.getOperand(0).getSimpleValueType();
812+
// We only need to close the gap between vXf64->vXf16.
813+
if (!VT.isVector() || VT.getVectorElementType() != MVT::f16 ||
814+
SrcVT.getVectorElementType() != MVT::f64)
815+
return Op;
816+
SDLoc DL(Op);
817+
MVT InterVT = MVT::getVectorVT(MVT::f32, VT.getVectorElementCount());
818+
SDValue IntermediateRound =
819+
DAG.getNode(RISCVISD::VFNCVT_ROD, DL, InterVT, Op.getOperand(0));
820+
return DAG.getFPExtendOrRound(IntermediateRound, DL, VT);
821+
}
822+
case ISD::FP_TO_SINT:
823+
case ISD::FP_TO_UINT:
824+
case ISD::SINT_TO_FP:
825+
case ISD::UINT_TO_FP: {
826+
// RVV can only do fp<->int conversions to types half/double the size of
827+
// the source. We custom-lower any conversions that do two hops into
828+
// sequences.
829+
MVT VT = Op.getSimpleValueType();
830+
if (!VT.isVector())
831+
return Op;
832+
SDLoc DL(Op);
833+
SDValue Src = Op.getOperand(0);
834+
MVT EltVT = VT.getVectorElementType();
835+
MVT SrcEltVT = Src.getSimpleValueType().getVectorElementType();
836+
unsigned EltSize = EltVT.getSizeInBits();
837+
unsigned SrcEltSize = SrcEltVT.getSizeInBits();
838+
assert(isPowerOf2_32(EltSize) && isPowerOf2_32(SrcEltSize) &&
839+
"Unexpected vector element types");
840+
bool IsInt2FP = SrcEltVT.isInteger();
841+
// Widening conversions
842+
if (EltSize > SrcEltSize && (EltSize / SrcEltSize >= 4)) {
843+
if (IsInt2FP) {
844+
// Do a regular integer sign/zero extension then convert to float.
845+
MVT IVecVT = MVT::getVectorVT(MVT::getIntegerVT(EltVT.getSizeInBits()),
846+
VT.getVectorElementCount());
847+
unsigned ExtOpcode = Op.getOpcode() == ISD::UINT_TO_FP
848+
? ISD::ZERO_EXTEND
849+
: ISD::SIGN_EXTEND;
850+
SDValue Ext = DAG.getNode(ExtOpcode, DL, IVecVT, Src);
851+
return DAG.getNode(Op.getOpcode(), DL, VT, Ext);
852+
}
853+
// FP2Int
854+
assert(SrcEltVT == MVT::f16 && "Unexpected FP_TO_[US]INT lowering");
855+
// Do one doubling fp_extend then complete the operation by converting
856+
// to int.
857+
MVT InterimFVT = MVT::getVectorVT(MVT::f32, VT.getVectorElementCount());
858+
SDValue FExt = DAG.getFPExtendOrRound(Src, DL, InterimFVT);
859+
return DAG.getNode(Op.getOpcode(), DL, VT, FExt);
860+
}
861+
862+
// Narrowing conversions
863+
if (SrcEltSize > EltSize && (SrcEltSize / EltSize >= 4)) {
864+
if (IsInt2FP) {
865+
// One narrowing int_to_fp, then an fp_round.
866+
assert(EltVT == MVT::f16 && "Unexpected [US]_TO_FP lowering");
867+
MVT InterimFVT = MVT::getVectorVT(MVT::f32, VT.getVectorElementCount());
868+
SDValue Int2FP = DAG.getNode(Op.getOpcode(), DL, InterimFVT, Src);
869+
return DAG.getFPExtendOrRound(Int2FP, DL, VT);
870+
}
871+
// FP2Int
872+
// One narrowing fp_to_int, then truncate the integer. If the float isn't
873+
// representable by the integer, the result is poison.
874+
MVT IVecVT =
875+
MVT::getVectorVT(MVT::getIntegerVT(SrcEltVT.getSizeInBits() / 2),
876+
VT.getVectorElementCount());
877+
SDValue FP2Int = DAG.getNode(Op.getOpcode(), DL, IVecVT, Src);
878+
return DAG.getNode(ISD::TRUNCATE, DL, VT, FP2Int);
879+
}
880+
881+
return Op;
882+
}
774883
}
775884
}
776885

@@ -4012,6 +4121,7 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
40124121
NODE_NAME_CASE(VSLIDEUP)
40134122
NODE_NAME_CASE(VSLIDEDOWN)
40144123
NODE_NAME_CASE(VID)
4124+
NODE_NAME_CASE(VFNCVT_ROD)
40154125
}
40164126
// clang-format on
40174127
return nullptr;

llvm/lib/Target/RISCV/RISCVISelLowering.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,10 @@ enum NodeType : unsigned {
106106
VSLIDEDOWN,
107107
// Matches the semantics of the unmasked vid.v instruction.
108108
VID,
109+
// Matches the semantics of the vfncvt.rod.f.f.w instruction (Convert double-width
110+
// float to single-width float, rounding towards odd). Takes a double-width
111+
// float vector and produces a single-width float vector.
112+
VFNCVT_ROD,
109113
};
110114
} // namespace RISCVISD
111115

llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,64 @@ multiclass VPatExtendSDNode_V<list<SDNode> ops, string inst_name, string suffix,
291291
}
292292
}
293293

294+
// Integer -> floating-point conversion patterns (section 14.15 users below).
// For every float vector type `fvti` in AllFloatVectors, matches `vop` applied
// to the integer vector type that GetIntVTypeInfo pairs with `fvti`
// (presumably the integer type of the same element width; GetIntVTypeInfo is
// defined outside this hunk) and selects the pseudo named
// `instruction_name`_<fvti LMUL>.
//   vop              - the SDNode to match (e.g. sint_to_fp / uint_to_fp).
//   instruction_name - pseudo-instruction name prefix, without LMUL suffix.
// The AVL and SEW operands are taken from the result (float) type.
multiclass VPatConvertI2FPSDNode_V<SDNode vop, string instruction_name> {
295+
foreach fvti = AllFloatVectors in {
296+
defvar ivti = GetIntVTypeInfo<fvti>.Vti;
297+
def : Pat<(fvti.Vector (vop (ivti.Vector ivti.RegClass:$rs1))),
298+
(!cast<Instruction>(instruction_name#"_"#fvti.LMul.MX)
299+
ivti.RegClass:$rs1, fvti.AVL, fvti.SEW)>;
300+
}
301+
}
302+
303+
// Floating-point -> integer conversion patterns (section 14.15 users below).
// For every float vector type `fvti` in AllFloatVectors, matches `vop` applied
// to an `fvti` operand producing the integer vector type that GetIntVTypeInfo
// pairs with `fvti` (presumably the same element width; GetIntVTypeInfo is
// defined outside this hunk) and selects `instruction_name`_<ivti LMUL>.
//   vop              - the SDNode to match (e.g. fp_to_sint / fp_to_uint).
//   instruction_name - pseudo-instruction name prefix, without LMUL suffix.
// The AVL and SEW operands are taken from the result (integer) type.
multiclass VPatConvertFP2ISDNode_V<SDNode vop, string instruction_name> {
304+
foreach fvti = AllFloatVectors in {
305+
defvar ivti = GetIntVTypeInfo<fvti>.Vti;
306+
def : Pat<(ivti.Vector (vop (fvti.Vector fvti.RegClass:$rs1))),
307+
(!cast<Instruction>(instruction_name#"_"#ivti.LMul.MX)
308+
fvti.RegClass:$rs1, ivti.AVL, ivti.SEW)>;
309+
}
310+
}
311+
312+
// Widening integer -> floating-point conversion patterns.
// Iterates the (Vti, Wti) pairs of AllWidenableIntToFloatVectors, treating
// Vti as the integer source type and Wti as the wider float result type.
//   vop              - the SDNode to match (e.g. sint_to_fp / uint_to_fp).
//   instruction_name - pseudo-instruction name prefix, without LMUL suffix.
// Note that, unlike the single-width patterns, the pseudo's LMUL suffix and
// the AVL/SEW operands all come from the *source* (narrow integer) type.
multiclass VPatWConvertI2FPSDNode_V<SDNode vop, string instruction_name> {
313+
foreach vtiToWti = AllWidenableIntToFloatVectors in {
314+
defvar ivti = vtiToWti.Vti;
315+
defvar fwti = vtiToWti.Wti;
316+
def : Pat<(fwti.Vector (vop (ivti.Vector ivti.RegClass:$rs1))),
317+
(!cast<Instruction>(instruction_name#"_"#ivti.LMul.MX)
318+
ivti.RegClass:$rs1, ivti.AVL, ivti.SEW)>;
319+
}
320+
}
321+
322+
// Widening floating-point -> integer conversion patterns.
// Iterates the (Vti, Wti) pairs of AllWidenableFloatVectors. The source is the
// narrow float type Vti; the result is the integer vector type that
// GetIntVTypeInfo pairs with the wide float type Wti (presumably the integer
// type of Wti's element width; GetIntVTypeInfo is defined outside this hunk).
//   vop              - the SDNode to match (e.g. fp_to_sint / fp_to_uint).
//   instruction_name - pseudo-instruction name prefix, without LMUL suffix.
// The pseudo's LMUL suffix and the AVL/SEW operands come from the *source*
// (narrow float) type.
multiclass VPatWConvertFP2ISDNode_V<SDNode vop, string instruction_name> {
323+
foreach fvtiToFWti = AllWidenableFloatVectors in {
324+
defvar fvti = fvtiToFWti.Vti;
325+
defvar iwti = GetIntVTypeInfo<fvtiToFWti.Wti>.Vti;
326+
def : Pat<(iwti.Vector (vop (fvti.Vector fvti.RegClass:$rs1))),
327+
(!cast<Instruction>(instruction_name#"_"#fvti.LMul.MX)
328+
fvti.RegClass:$rs1, fvti.AVL, fvti.SEW)>;
329+
}
330+
}
331+
332+
// Narrowing integer -> floating-point conversion patterns.
// Iterates the (Vti, Wti) pairs of AllWidenableFloatVectors. The result is the
// narrow float type Vti; the source is the integer vector type that
// GetIntVTypeInfo pairs with the wide float type Wti (presumably the integer
// type of Wti's element width; GetIntVTypeInfo is defined outside this hunk).
//   vop              - the SDNode to match (e.g. sint_to_fp / uint_to_fp).
//   instruction_name - pseudo-instruction name prefix, without LMUL suffix.
// The pseudo's LMUL suffix and the AVL/SEW operands come from the *result*
// (narrow float) type, while the register operand uses the wide integer
// type's register class.
multiclass VPatNConvertI2FPSDNode_V<SDNode vop, string instruction_name> {
333+
foreach fvtiToFWti = AllWidenableFloatVectors in {
334+
defvar fvti = fvtiToFWti.Vti;
335+
defvar iwti = GetIntVTypeInfo<fvtiToFWti.Wti>.Vti;
336+
def : Pat<(fvti.Vector (vop (iwti.Vector iwti.RegClass:$rs1))),
337+
(!cast<Instruction>(instruction_name#"_"#fvti.LMul.MX)
338+
iwti.RegClass:$rs1, fvti.AVL, fvti.SEW)>;
339+
}
340+
}
341+
342+
// Narrowing floating-point -> integer conversion patterns.
// Iterates the (Vti, Wti) pairs of AllWidenableIntToFloatVectors, treating
// Wti as the wide float source type and Vti as the narrow integer result type.
//   vop              - the SDNode to match (e.g. fp_to_sint / fp_to_uint).
//   instruction_name - pseudo-instruction name prefix, without LMUL suffix.
// The pseudo's LMUL suffix and the AVL/SEW operands come from the *result*
// (narrow integer) type, while the register operand uses the wide float
// type's register class.
multiclass VPatNConvertFP2ISDNode_V<SDNode vop, string instruction_name> {
343+
foreach vtiToWti = AllWidenableIntToFloatVectors in {
344+
defvar vti = vtiToWti.Vti;
345+
defvar fwti = vtiToWti.Wti;
346+
def : Pat<(vti.Vector (vop (fwti.Vector fwti.RegClass:$rs1))),
347+
(!cast<Instruction>(instruction_name#"_"#vti.LMul.MX)
348+
fwti.RegClass:$rs1, vti.AVL, vti.SEW)>;
349+
}
350+
}
351+
294352
//===----------------------------------------------------------------------===//
295353
// Patterns.
296354
//===----------------------------------------------------------------------===//
@@ -440,6 +498,10 @@ foreach mti = AllMasks in {
440498

441499
} // Predicates = [HasStdExtV]
442500

501+
// SelectionDAG node definition bound to the RISCVISD::VFNCVT_ROD opcode.
// The type profile declares one result and one operand and only constrains
// both to be vector types; the node carries no extra properties (empty list).
def riscv_fncvt_rod
502+
: SDNode<"RISCVISD::VFNCVT_ROD",
503+
SDTypeProfile<1, 1, [SDTCisVec<0>, SDTCisVec<1>]>, []>;
504+
443505
let Predicates = [HasStdExtV, HasStdExtF] in {
444506

445507
// 14.2. Vector Single-Width Floating-Point Add/Subtract Instructions
@@ -489,6 +551,43 @@ foreach fvti = AllFloatVectors in {
489551
(!cast<Instruction>("PseudoVMERGE_VIM_"#fvti.LMul.MX)
490552
fvti.RegClass:$rs2, 0, VMV0:$vm, fvti.AVL, fvti.SEW)>;
491553
}
554+
555+
// 14.15. Vector Single-Width Floating-Point/Integer Type-Convert Instructions
556+
defm "" : VPatConvertFP2ISDNode_V<fp_to_sint, "PseudoVFCVT_RTZ_X_F_V">;
557+
defm "" : VPatConvertFP2ISDNode_V<fp_to_uint, "PseudoVFCVT_RTZ_XU_F_V">;
558+
defm "" : VPatConvertI2FPSDNode_V<sint_to_fp, "PseudoVFCVT_F_X_V">;
559+
defm "" : VPatConvertI2FPSDNode_V<uint_to_fp, "PseudoVFCVT_F_XU_V">;
560+
561+
// 14.16. Widening Floating-Point/Integer Type-Convert Instructions
562+
defm "" : VPatWConvertFP2ISDNode_V<fp_to_sint, "PseudoVFWCVT_RTZ_X_F_V">;
563+
defm "" : VPatWConvertFP2ISDNode_V<fp_to_uint, "PseudoVFWCVT_RTZ_XU_F_V">;
564+
defm "" : VPatWConvertI2FPSDNode_V<sint_to_fp, "PseudoVFWCVT_F_X_V">;
565+
defm "" : VPatWConvertI2FPSDNode_V<uint_to_fp, "PseudoVFWCVT_F_XU_V">;
566+
foreach fvtiToFWti = AllWidenableFloatVectors in {
567+
defvar fvti = fvtiToFWti.Vti;
568+
defvar fwti = fvtiToFWti.Wti;
569+
def : Pat<(fwti.Vector (fpextend (fvti.Vector fvti.RegClass:$rs1))),
570+
(!cast<Instruction>("PseudoVFWCVT_F_F_V_"#fvti.LMul.MX)
571+
fvti.RegClass:$rs1, fvti.AVL, fvti.SEW)>;
572+
}
573+
574+
// 14.17. Narrowing Floating-Point/Integer Type-Convert Instructions
575+
defm "" : VPatNConvertFP2ISDNode_V<fp_to_sint, "PseudoVFNCVT_RTZ_X_F_W">;
576+
defm "" : VPatNConvertFP2ISDNode_V<fp_to_uint, "PseudoVFNCVT_RTZ_XU_F_W">;
577+
defm "" : VPatNConvertI2FPSDNode_V<sint_to_fp, "PseudoVFNCVT_F_X_W">;
578+
defm "" : VPatNConvertI2FPSDNode_V<uint_to_fp, "PseudoVFNCVT_F_XU_W">;
579+
foreach fvtiToFWti = AllWidenableFloatVectors in {
580+
defvar fvti = fvtiToFWti.Vti;
581+
defvar fwti = fvtiToFWti.Wti;
582+
def : Pat<(fvti.Vector (fpround (fwti.Vector fwti.RegClass:$rs1))),
583+
(!cast<Instruction>("PseudoVFNCVT_F_F_W_"#fvti.LMul.MX)
584+
fwti.RegClass:$rs1, fvti.AVL, fvti.SEW)>;
585+
586+
def : Pat<(fvti.Vector (riscv_fncvt_rod (fwti.Vector fwti.RegClass:$rs1))),
587+
(!cast<Instruction>("PseudoVFNCVT_ROD_F_F_W_"#fvti.LMul.MX)
588+
fwti.RegClass:$rs1, fvti.AVL, fvti.SEW)>;
589+
}
590+
492591
} // Predicates = [HasStdExtV, HasStdExtF]
493592

494593
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)