@@ -826,7 +826,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
826
826
// We have some custom DAG combine patterns for these nodes
827
827
setTargetDAGCombine ({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
828
828
ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
829
- ISD::BUILD_VECTOR, ISD::ADDRSPACECAST});
829
+ ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::FP_ROUND });
830
830
831
831
// setcc for f16x2 and bf16x2 needs special handling to prevent
832
832
// legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -5713,6 +5713,40 @@ static SDValue combineADDRSPACECAST(SDNode *N,
5713
5713
return SDValue ();
5714
5714
}
5715
5715
5716
+ static SDValue PerformFP_ROUNDCombine (SDNode *N,
5717
+ TargetLowering::DAGCombinerInfo &DCI) {
5718
+ SDLoc DL (N);
5719
+ SDValue Op = N->getOperand (0 );
5720
+ SDValue Trunc = N->getOperand (1 );
5721
+ EVT NarrowVT = N->getValueType (0 );
5722
+ EVT WideVT = Op.getValueType ();
5723
+
5724
+ // v2[b]f16 = fp_round (v2f32 A)
5725
+ // -> v2[b]f16 = (build_vector ([b]f16 = fp_round (extractelt A, 0)),
5726
+ // ([b]f16 = fp_round (extractelt A, 1)))
5727
+ if ((NarrowVT == MVT::v2bf16 || NarrowVT == MVT::v2f16) &&
5728
+ WideVT == MVT::v2f32) {
5729
+ SDValue F32Op0, F32Op1;
5730
+ if (Op.getOpcode () == ISD::BUILD_VECTOR) {
5731
+ F32Op0 = Op.getOperand (0 );
5732
+ F32Op1 = Op.getOperand (1 );
5733
+ } else {
5734
+ F32Op0 = DCI.DAG .getNode (ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32 , Op,
5735
+ DCI.DAG .getIntPtrConstant (0 , DL));
5736
+ F32Op1 = DCI.DAG .getNode (ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32 , Op,
5737
+ DCI.DAG .getIntPtrConstant (1 , DL));
5738
+ }
5739
+ return DCI.DAG .getBuildVector (
5740
+ NarrowVT, DL,
5741
+ {DCI.DAG .getNode (ISD::FP_ROUND, DL, NarrowVT.getScalarType (), F32Op0,
5742
+ Trunc),
5743
+ DCI.DAG .getNode (ISD::FP_ROUND, DL, NarrowVT.getScalarType (), F32Op1,
5744
+ Trunc)});
5745
+ }
5746
+
5747
+ return SDValue ();
5748
+ }
5749
+
5716
5750
SDValue NVPTXTargetLowering::PerformDAGCombine (SDNode *N,
5717
5751
DAGCombinerInfo &DCI) const {
5718
5752
CodeGenOptLevel OptLevel = getTargetMachine ().getOptLevel ();
@@ -5749,6 +5783,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
5749
5783
return PerformBUILD_VECTORCombine (N, DCI);
5750
5784
case ISD::ADDRSPACECAST:
5751
5785
return combineADDRSPACECAST (N, DCI);
5786
+ case ISD::FP_ROUND:
5787
+ return PerformFP_ROUNDCombine (N, DCI);
5752
5788
}
5753
5789
return SDValue ();
5754
5790
}
0 commit comments