@@ -402,6 +402,15 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
402
402
setOperationAction (ISD::SIGN_EXTEND, VT, Custom);
403
403
setOperationAction (ISD::ZERO_EXTEND, VT, Custom);
404
404
405
+ // RVV has native int->float & float->int conversions where the
406
+ // element type sizes are within one power-of-two of each other. Any
407
+ // wider distances between type sizes have to be lowered as sequences
408
+ // which progressively narrow the gap in stages.
409
+ setOperationAction (ISD::SINT_TO_FP, VT, Custom);
410
+ setOperationAction (ISD::UINT_TO_FP, VT, Custom);
411
+ setOperationAction (ISD::FP_TO_SINT, VT, Custom);
412
+ setOperationAction (ISD::FP_TO_UINT, VT, Custom);
413
+
405
414
// Integer VTs are lowered as a series of "RISCVISD::TRUNCATE_VECTOR"
406
415
// nodes which truncate by one power of two at a time.
407
416
setOperationAction (ISD::TRUNCATE, VT, Custom);
@@ -427,9 +436,16 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
427
436
// Sets common operation actions on RVV floating-point vector types.
428
437
const auto SetCommonVFPActions = [&](MVT VT) {
429
438
setOperationAction (ISD::SPLAT_VECTOR, VT, Legal);
439
+ // RVV has native FP_ROUND & FP_EXTEND conversions where the element type
440
+ // sizes are within one power-of-two of each other. Therefore conversions
441
+ // between vXf16 and vXf64 must be lowered as sequences which convert via
442
+ // vXf32.
443
+ setOperationAction (ISD::FP_ROUND, VT, Custom);
444
+ setOperationAction (ISD::FP_EXTEND, VT, Custom);
430
445
// Custom-lower insert/extract operations to simplify patterns.
431
446
setOperationAction (ISD::INSERT_VECTOR_ELT, VT, Custom);
432
447
setOperationAction (ISD::EXTRACT_VECTOR_ELT, VT, Custom);
448
+ // Expand various condition codes (explained above).
433
449
for (auto CC : VFPCCToExpand)
434
450
setCondCodeAction (CC, VT, Expand);
435
451
};
@@ -771,6 +787,99 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
771
787
DAG.getConstant (3 , DL, VT));
772
788
return DAG.getNode (ISD::MUL, DL, VT, VScale, Op.getOperand (0 ));
773
789
}
790
+ case ISD::FP_EXTEND: {
791
+ // RVV can only do fp_extend to types double the size as the source. We
792
+ // custom-lower f16->f64 extensions to two hops of ISD::FP_EXTEND, going
793
+ // via f32.
794
+ MVT VT = Op.getSimpleValueType ();
795
+ MVT SrcVT = Op.getOperand (0 ).getSimpleValueType ();
796
+ // We only need to close the gap between vXf16->vXf64.
797
+ if (!VT.isVector () || VT.getVectorElementType () != MVT::f64 ||
798
+ SrcVT.getVectorElementType () != MVT::f16 )
799
+ return Op;
800
+ SDLoc DL (Op);
801
+ MVT InterVT = MVT::getVectorVT (MVT::f32 , VT.getVectorElementCount ());
802
+ SDValue IntermediateRound =
803
+ DAG.getFPExtendOrRound (Op.getOperand (0 ), DL, InterVT);
804
+ return DAG.getFPExtendOrRound (IntermediateRound, DL, VT);
805
+ }
806
+ case ISD::FP_ROUND: {
807
+ // RVV can only do fp_round to types half the size as the source. We
808
+ // custom-lower f64->f16 rounds via RVV's round-to-odd float
809
+ // conversion instruction.
810
+ MVT VT = Op.getSimpleValueType ();
811
+ MVT SrcVT = Op.getOperand (0 ).getSimpleValueType ();
812
+ // We only need to close the gap between vXf64<->vXf16.
813
+ if (!VT.isVector () || VT.getVectorElementType () != MVT::f16 ||
814
+ SrcVT.getVectorElementType () != MVT::f64 )
815
+ return Op;
816
+ SDLoc DL (Op);
817
+ MVT InterVT = MVT::getVectorVT (MVT::f32 , VT.getVectorElementCount ());
818
+ SDValue IntermediateRound =
819
+ DAG.getNode (RISCVISD::VFNCVT_ROD, DL, InterVT, Op.getOperand (0 ));
820
+ return DAG.getFPExtendOrRound (IntermediateRound, DL, VT);
821
+ }
822
+ case ISD::FP_TO_SINT:
823
+ case ISD::FP_TO_UINT:
824
+ case ISD::SINT_TO_FP:
825
+ case ISD::UINT_TO_FP: {
826
+ // RVV can only do fp<->int conversions to types half/double the size as
827
+ // the source. We custom-lower any conversions that do two hops into
828
+ // sequences.
829
+ MVT VT = Op.getSimpleValueType ();
830
+ if (!VT.isVector ())
831
+ return Op;
832
+ SDLoc DL (Op);
833
+ SDValue Src = Op.getOperand (0 );
834
+ MVT EltVT = VT.getVectorElementType ();
835
+ MVT SrcEltVT = Src.getSimpleValueType ().getVectorElementType ();
836
+ unsigned EltSize = EltVT.getSizeInBits ();
837
+ unsigned SrcEltSize = SrcEltVT.getSizeInBits ();
838
+ assert (isPowerOf2_32 (EltSize) && isPowerOf2_32 (SrcEltSize) &&
839
+ " Unexpected vector element types" );
840
+ bool IsInt2FP = SrcEltVT.isInteger ();
841
+ // Widening conversions
842
+ if (EltSize > SrcEltSize && (EltSize / SrcEltSize >= 4 )) {
843
+ if (IsInt2FP) {
844
+ // Do a regular integer sign/zero extension then convert to float.
845
+ MVT IVecVT = MVT::getVectorVT (MVT::getIntegerVT (EltVT.getSizeInBits ()),
846
+ VT.getVectorElementCount ());
847
+ unsigned ExtOpcode = Op.getOpcode () == ISD::UINT_TO_FP
848
+ ? ISD::ZERO_EXTEND
849
+ : ISD::SIGN_EXTEND;
850
+ SDValue Ext = DAG.getNode (ExtOpcode, DL, IVecVT, Src);
851
+ return DAG.getNode (Op.getOpcode (), DL, VT, Ext);
852
+ }
853
+ // FP2Int
854
+ assert (SrcEltVT == MVT::f16 && " Unexpected FP_TO_[US]INT lowering" );
855
+ // Do one doubling fp_extend then complete the operation by converting
856
+ // to int.
857
+ MVT InterimFVT = MVT::getVectorVT (MVT::f32 , VT.getVectorElementCount ());
858
+ SDValue FExt = DAG.getFPExtendOrRound (Src, DL, InterimFVT);
859
+ return DAG.getNode (Op.getOpcode (), DL, VT, FExt);
860
+ }
861
+
862
+ // Narrowing conversions
863
+ if (SrcEltSize > EltSize && (SrcEltSize / EltSize >= 4 )) {
864
+ if (IsInt2FP) {
865
+ // One narrowing int_to_fp, then an fp_round.
866
+ assert (EltVT == MVT::f16 && " Unexpected [US]_TO_FP lowering" );
867
+ MVT InterimFVT = MVT::getVectorVT (MVT::f32 , VT.getVectorElementCount ());
868
+ SDValue Int2FP = DAG.getNode (Op.getOpcode (), DL, InterimFVT, Src);
869
+ return DAG.getFPExtendOrRound (Int2FP, DL, VT);
870
+ }
871
+ // FP2Int
872
+ // One narrowing fp_to_int, then truncate the integer. If the float isn't
873
+ // representable by the integer, the result is poison.
874
+ MVT IVecVT =
875
+ MVT::getVectorVT (MVT::getIntegerVT (SrcEltVT.getSizeInBits () / 2 ),
876
+ VT.getVectorElementCount ());
877
+ SDValue FP2Int = DAG.getNode (Op.getOpcode (), DL, IVecVT, Src);
878
+ return DAG.getNode (ISD::TRUNCATE, DL, VT, FP2Int);
879
+ }
880
+
881
+ return Op;
882
+ }
774
883
}
775
884
}
776
885
@@ -4012,6 +4121,7 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
4012
4121
NODE_NAME_CASE (VSLIDEUP)
4013
4122
NODE_NAME_CASE (VSLIDEDOWN)
4014
4123
NODE_NAME_CASE (VID)
4124
+ NODE_NAME_CASE (VFNCVT_ROD)
4015
4125
}
4016
4126
// clang-format on
4017
4127
return nullptr ;
0 commit comments