@@ -832,7 +832,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
832832 // We have some custom DAG combine patterns for these nodes
833833 setTargetDAGCombine ({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
834834 ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
835- ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::FP_ROUND});
835+ ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::FP_ROUND,
836+ ISD::TRUNCATE});
836837
837838 // setcc for f16x2 and bf16x2 needs special handling to prevent
838839 // legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -5732,6 +5733,49 @@ static SDValue PerformFP_ROUNDCombine(SDNode *N,
57325733 return SDValue ();
57335734}
57345735
5736+ static SDValue PerformTRUNCATECombine (SDNode *N,
5737+ TargetLowering::DAGCombinerInfo &DCI) {
5738+ SDLoc DL (N);
5739+ SDValue Op = N->getOperand (0 );
5740+ EVT FromVT = Op.getValueType ();
5741+ EVT ResultVT = N->getValueType (0 );
5742+
5743+ if (FromVT == MVT::i64 && ResultVT == MVT::i32 ) {
5744+ // i32 = truncate (i64 = bitcast (v2f32 = BUILD_VECTOR (f32 A, f32 B)))
5745+ // -> i32 = bitcast (f32 A)
5746+ if (Op.getOpcode () == ISD::BITCAST) {
5747+ SDValue BV = Op.getOperand (0 );
5748+ if (BV.getOpcode () == ISD::BUILD_VECTOR &&
5749+ BV.getValueType () == MVT::v2f32) {
5750+ // get lower
5751+ return DCI.DAG .getNode (ISD::BITCAST, DL, ResultVT, BV.getOperand (0 ));
5752+ }
5753+ }
5754+
5755+ // i32 = truncate (i64 = srl
5756+ // (i64 = bitcast
5757+ // (v2f32 = BUILD_VECTOR (f32 A, f32 B))), 32)
5758+ // -> i32 = bitcast (f32 B)
5759+ if (Op.getOpcode () == ISD::SRL) {
5760+ if (auto *ShAmt = dyn_cast<ConstantSDNode>(Op.getOperand (1 ));
5761+ ShAmt && ShAmt->getAsAPIntVal () == 32 ) {
5762+ SDValue Cast = Op.getOperand (0 );
5763+ if (Cast.getOpcode () == ISD::BITCAST) {
5764+ SDValue BV = Cast.getOperand (0 );
5765+ if (BV.getOpcode () == ISD::BUILD_VECTOR &&
5766+ BV.getValueType () == MVT::v2f32) {
5767+ // get upper
5768+ return DCI.DAG .getNode (ISD::BITCAST, DL, ResultVT,
5769+ BV.getOperand (1 ));
5770+ }
5771+ }
5772+ }
5773+ }
5774+ }
5775+
5776+ return SDValue ();
5777+ }
5778+
57355779SDValue NVPTXTargetLowering::PerformDAGCombine (SDNode *N,
57365780 DAGCombinerInfo &DCI) const {
57375781 CodeGenOptLevel OptLevel = getTargetMachine ().getOptLevel ();
@@ -5770,6 +5814,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
57705814 return combineADDRSPACECAST (N, DCI);
57715815 case ISD::FP_ROUND:
57725816 return PerformFP_ROUNDCombine (N, DCI);
5817+ case ISD::TRUNCATE:
5818+ return PerformTRUNCATECombine (N, DCI);
57735819 }
57745820 return SDValue ();
57755821}
0 commit comments