@@ -828,7 +828,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
828
828
// We have some custom DAG combine patterns for these nodes
829
829
setTargetDAGCombine ({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
830
830
ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
831
- ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::FP_ROUND});
831
+ ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::FP_ROUND,
832
+ ISD::TRUNCATE});
832
833
833
834
// setcc for f16x2 and bf16x2 needs special handling to prevent
834
835
// legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -5752,6 +5753,49 @@ static SDValue PerformFP_ROUNDCombine(SDNode *N,
5752
5753
return SDValue ();
5753
5754
}
5754
5755
5756
+ static SDValue PerformTRUNCATECombine (SDNode *N,
5757
+ TargetLowering::DAGCombinerInfo &DCI) {
5758
+ SDLoc DL (N);
5759
+ SDValue Op = N->getOperand (0 );
5760
+ EVT FromVT = Op.getValueType ();
5761
+ EVT ResultVT = N->getValueType (0 );
5762
+
5763
+ if (FromVT == MVT::i64 && ResultVT == MVT::i32 ) {
5764
+ // i32 = truncate (i64 = bitcast (v2f32 = BUILD_VECTOR (f32 A, f32 B)))
5765
+ // -> i32 = bitcast (f32 A)
5766
+ if (Op.getOpcode () == ISD::BITCAST) {
5767
+ SDValue BV = Op.getOperand (0 );
5768
+ if (BV.getOpcode () == ISD::BUILD_VECTOR &&
5769
+ BV.getValueType () == MVT::v2f32) {
5770
+ // get lower
5771
+ return DCI.DAG .getNode (ISD::BITCAST, DL, ResultVT, BV.getOperand (0 ));
5772
+ }
5773
+ }
5774
+
5775
+ // i32 = truncate (i64 = srl
5776
+ // (i64 = bitcast
5777
+ // (v2f32 = BUILD_VECTOR (f32 A, f32 B))), 32)
5778
+ // -> i32 = bitcast (f32 B)
5779
+ if (Op.getOpcode () == ISD::SRL) {
5780
+ if (auto *ShAmt = dyn_cast<ConstantSDNode>(Op.getOperand (1 ));
5781
+ ShAmt && ShAmt->getAsAPIntVal () == 32 ) {
5782
+ SDValue Cast = Op.getOperand (0 );
5783
+ if (Cast.getOpcode () == ISD::BITCAST) {
5784
+ SDValue BV = Cast.getOperand (0 );
5785
+ if (BV.getOpcode () == ISD::BUILD_VECTOR &&
5786
+ BV.getValueType () == MVT::v2f32) {
5787
+ // get upper
5788
+ return DCI.DAG .getNode (ISD::BITCAST, DL, ResultVT,
5789
+ BV.getOperand (1 ));
5790
+ }
5791
+ }
5792
+ }
5793
+ }
5794
+ }
5795
+
5796
+ return SDValue ();
5797
+ }
5798
+
5755
5799
SDValue NVPTXTargetLowering::PerformDAGCombine (SDNode *N,
5756
5800
DAGCombinerInfo &DCI) const {
5757
5801
CodeGenOptLevel OptLevel = getTargetMachine ().getOptLevel ();
@@ -5790,6 +5834,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
5790
5834
return combineADDRSPACECAST (N, DCI);
5791
5835
case ISD::FP_ROUND:
5792
5836
return PerformFP_ROUNDCombine (N, DCI);
5837
+ case ISD::TRUNCATE:
5838
+ return PerformTRUNCATECombine (N, DCI);
5793
5839
}
5794
5840
return SDValue ();
5795
5841
}
0 commit comments