@@ -828,7 +828,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
828828 // We have some custom DAG combine patterns for these nodes
829829 setTargetDAGCombine ({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
830830 ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
831- ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::FP_ROUND});
831+ ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::FP_ROUND,
832+ ISD::TRUNCATE});
832833
833834 // setcc for f16x2 and bf16x2 needs special handling to prevent
834835 // legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -5752,6 +5753,49 @@ static SDValue PerformFP_ROUNDCombine(SDNode *N,
57525753 return SDValue ();
57535754}
57545755
5756+ static SDValue PerformTRUNCATECombine (SDNode *N,
5757+ TargetLowering::DAGCombinerInfo &DCI) {
5758+ SDLoc DL (N);
5759+ SDValue Op = N->getOperand (0 );
5760+ EVT FromVT = Op.getValueType ();
5761+ EVT ResultVT = N->getValueType (0 );
5762+
5763+ if (FromVT == MVT::i64 && ResultVT == MVT::i32 ) {
5764+ // i32 = truncate (i64 = bitcast (v2f32 = BUILD_VECTOR (f32 A, f32 B)))
5765+ // -> i32 = bitcast (f32 A)
5766+ if (Op.getOpcode () == ISD::BITCAST) {
5767+ SDValue BV = Op.getOperand (0 );
5768+ if (BV.getOpcode () == ISD::BUILD_VECTOR &&
5769+ BV.getValueType () == MVT::v2f32) {
5770+ // get lower
5771+ return DCI.DAG .getNode (ISD::BITCAST, DL, ResultVT, BV.getOperand (0 ));
5772+ }
5773+ }
5774+
5775+ // i32 = truncate (i64 = srl
5776+ // (i64 = bitcast
5777+ // (v2f32 = BUILD_VECTOR (f32 A, f32 B))), 32)
5778+ // -> i32 = bitcast (f32 B)
5779+ if (Op.getOpcode () == ISD::SRL) {
5780+ if (auto *ShAmt = dyn_cast<ConstantSDNode>(Op.getOperand (1 ));
5781+ ShAmt && ShAmt->getAsAPIntVal () == 32 ) {
5782+ SDValue Cast = Op.getOperand (0 );
5783+ if (Cast.getOpcode () == ISD::BITCAST) {
5784+ SDValue BV = Cast.getOperand (0 );
5785+ if (BV.getOpcode () == ISD::BUILD_VECTOR &&
5786+ BV.getValueType () == MVT::v2f32) {
5787+ // get upper
5788+ return DCI.DAG .getNode (ISD::BITCAST, DL, ResultVT,
5789+ BV.getOperand (1 ));
5790+ }
5791+ }
5792+ }
5793+ }
5794+ }
5795+
5796+ return SDValue ();
5797+ }
5798+
57555799SDValue NVPTXTargetLowering::PerformDAGCombine (SDNode *N,
57565800 DAGCombinerInfo &DCI) const {
57575801 CodeGenOptLevel OptLevel = getTargetMachine ().getOptLevel ();
@@ -5790,6 +5834,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
57905834 return combineADDRSPACECAST (N, DCI);
57915835 case ISD::FP_ROUND:
57925836 return PerformFP_ROUNDCombine (N, DCI);
5837+ case ISD::TRUNCATE:
5838+ return PerformTRUNCATECombine (N, DCI);
57935839 }
57945840 return SDValue ();
57955841}
0 commit comments