@@ -757,7 +757,9 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
757757 setOperationAction(ISD::FABS, MVT::v2f16, Legal);
758758
759759 // Can do this in one BFI plus a constant materialize.
760- setOperationAction(ISD::FCOPYSIGN, {MVT::v2f16, MVT::v2bf16}, Custom);
760+ setOperationAction(ISD::FCOPYSIGN,
761+ {MVT::v2f16, MVT::v2bf16, MVT::v4f16, MVT::v4bf16},
762+ Custom);
761763
762764 setOperationAction({ISD::FMAXNUM, ISD::FMINNUM}, MVT::f16, Custom);
763765 setOperationAction({ISD::FMAXNUM_IEEE, ISD::FMINNUM_IEEE}, MVT::f16, Legal);
@@ -5936,10 +5938,11 @@ SDValue SITargetLowering::splitBinaryVectorOp(SDValue Op,
59365938 SelectionDAG &DAG) const {
59375939 unsigned Opc = Op.getOpcode();
59385940 EVT VT = Op.getValueType();
5939- assert(VT == MVT::v4i16 || VT == MVT::v4f16 || VT == MVT::v4f32 ||
5940- VT == MVT::v8i16 || VT == MVT::v8f16 || VT == MVT::v16i16 ||
5941- VT == MVT::v16f16 || VT == MVT::v8f32 || VT == MVT::v16f32 ||
5942- VT == MVT::v32f32 || VT == MVT::v32i16 || VT == MVT::v32f16);
5941+ assert(VT == MVT::v4i16 || VT == MVT::v4f16 || VT == MVT::v4bf16 ||
5942+ VT == MVT::v4f32 || VT == MVT::v8i16 || VT == MVT::v8f16 ||
5943+ VT == MVT::v16i16 || VT == MVT::v16f16 || VT == MVT::v8f32 ||
5944+ VT == MVT::v16f32 || VT == MVT::v32f32 || VT == MVT::v32i16 ||
5945+ VT == MVT::v32f16);
59435946
59445947 auto [Lo0, Hi0] = DAG.SplitVectorOperand(Op.getNode(), 0);
59455948 auto [Lo1, Hi1] = DAG.SplitVectorOperand(Op.getNode(), 1);
@@ -7122,18 +7125,17 @@ SDValue SITargetLowering::promoteUniformOpToI32(SDValue Op,
71227125
71237126SDValue SITargetLowering::lowerFCOPYSIGN(SDValue Op, SelectionDAG &DAG) const {
71247127 SDValue Mag = Op.getOperand(0);
7125- SDValue Sign = Op.getOperand(1);
7126-
71277128 EVT MagVT = Mag.getValueType();
7128- EVT SignVT = Sign.getValueType();
71297129
7130- assert(MagVT.isVector());
7130+ if (MagVT.getVectorNumElements() > 2)
7131+ return splitBinaryVectorOp(Op, DAG);
7132+
7133+ SDValue Sign = Op.getOperand(1);
7134+ EVT SignVT = Sign.getValueType();
71317135
71327136 if (MagVT == SignVT)
71337137 return Op;
71347138
7135- assert(MagVT.getVectorNumElements() == 2);
7136-
71377139 // fcopysign v2f16:mag, v2f32:sign ->
71387140 // fcopysign v2f16:mag, bitcast (trunc (bitcast sign to v2i32) to v2i16)
71397141
0 commit comments