@@ -830,7 +830,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
830830 // We have some custom DAG combine patterns for these nodes
831831 setTargetDAGCombine ({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
832832 ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
833- ISD::BUILD_VECTOR, ISD::ADDRSPACECAST});
833+ ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::FP_ROUND });
834834
835835 // setcc for f16x2 and bf16x2 needs special handling to prevent
836836 // legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -5693,6 +5693,40 @@ static SDValue combineADDRSPACECAST(SDNode *N,
56935693 return SDValue ();
56945694}
56955695
5696+ static SDValue PerformFP_ROUNDCombine (SDNode *N,
5697+ TargetLowering::DAGCombinerInfo &DCI) {
5698+ SDLoc DL (N);
5699+ SDValue Op = N->getOperand (0 );
5700+ SDValue Trunc = N->getOperand (1 );
5701+ EVT NarrowVT = N->getValueType (0 );
5702+ EVT WideVT = Op.getValueType ();
5703+
5704+ // v2[b]f16 = fp_round (v2f32 A)
5705+ // -> v2[b]f16 = (build_vector ([b]f16 = fp_round (extractelt A, 0)),
5706+ // ([b]f16 = fp_round (extractelt A, 1)))
5707+ if ((NarrowVT == MVT::v2bf16 || NarrowVT == MVT::v2f16) &&
5708+ WideVT == MVT::v2f32) {
5709+ SDValue F32Op0, F32Op1;
5710+ if (Op.getOpcode () == ISD::BUILD_VECTOR) {
5711+ F32Op0 = Op.getOperand (0 );
5712+ F32Op1 = Op.getOperand (1 );
5713+ } else {
5714+ F32Op0 = DCI.DAG .getNode (ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32 , Op,
5715+ DCI.DAG .getIntPtrConstant (0 , DL));
5716+ F32Op1 = DCI.DAG .getNode (ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32 , Op,
5717+ DCI.DAG .getIntPtrConstant (1 , DL));
5718+ }
5719+ return DCI.DAG .getBuildVector (
5720+ NarrowVT, DL,
5721+ {DCI.DAG .getNode (ISD::FP_ROUND, DL, NarrowVT.getScalarType (), F32Op0,
5722+ Trunc),
5723+ DCI.DAG .getNode (ISD::FP_ROUND, DL, NarrowVT.getScalarType (), F32Op1,
5724+ Trunc)});
5725+ }
5726+
5727+ return SDValue ();
5728+ }
5729+
56965730SDValue NVPTXTargetLowering::PerformDAGCombine (SDNode *N,
56975731 DAGCombinerInfo &DCI) const {
56985732 CodeGenOptLevel OptLevel = getTargetMachine ().getOptLevel ();
@@ -5729,6 +5763,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
57295763 return PerformBUILD_VECTORCombine (N, DCI);
57305764 case ISD::ADDRSPACECAST:
57315765 return combineADDRSPACECAST (N, DCI);
5766+ case ISD::FP_ROUND:
5767+ return PerformFP_ROUNDCombine (N, DCI);
57325768 }
57335769 return SDValue ();
57345770}
0 commit comments