@@ -872,6 +872,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
872872 setOperationAction (Op, MVT::v2f32, Custom);
873873 // Handle custom lowering for: i64 = bitcast v2f32
874874 setOperationAction (ISD::BITCAST, MVT::v2f32, Custom);
875+ // Handle custom lowering for: f32 = extract_vector_elt v2f32
876+ setOperationAction (ISD::EXTRACT_VECTOR_ELT, MVT::v2f32, Custom);
875877 }
876878
877879 // These map to conversion instructions for scalar FP types.
@@ -2253,6 +2255,20 @@ SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
22532255 return DAG.getAnyExtOrTrunc (BFE, DL, Op->getValueType (0 ));
22542256 }
22552257
2258+ if (VectorVT == MVT::v2f32) {
2259+ if (Vector.getOpcode () == ISD::BITCAST) {
2260+ // peek through v2f32 = bitcast (i64 = build_pair (i32 A, i32 B))
2261+ // where A:i32, B:i32 = CopyFromReg (i64 = F32X2 Operation ...)
2262+ SDValue Pair = Vector.getOperand (0 );
2263+ assert (Pair.getOpcode () == ISD::BUILD_PAIR);
2264+ return DAG.getNode (
2265+ ISD::BITCAST, DL, Op.getValueType (),
2266+ Pair.getOperand (cast<ConstantSDNode>(Index)->getZExtValue ()));
2267+ }
2268+ if (Vector.getOpcode () == ISD::BUILD_VECTOR)
2269+ return Vector.getOperand (cast<ConstantSDNode>(Index)->getZExtValue ());
2270+ }
2271+
22562272 // Constant index will be matched by tablegen.
22572273 if (isa<ConstantSDNode>(Index.getNode ()))
22582274 return Op;
@@ -5565,9 +5581,22 @@ static void ReplaceF32x2Op(SDNode *N, SelectionDAG &DAG,
55655581 for (const SDValue &Op : N->ops ())
55665582 NewOps.push_back (DAG.getNode (ISD::BITCAST, DL, MVT::i64 , Op));
55675583
5568- // cast i64 result of new op back to <2 x float>
5584+ SDValue Chain = DAG.getEntryNode ();
5585+
5586+ // break i64 result into two i32 registers for later instructions that may
5587+ // access element #0 or #1. otherwise, this code will be eliminated
55695588 SDValue NewValue = DAG.getNode (Opcode, DL, MVT::i64 , NewOps);
5570- Results.push_back (DAG.getBitcast (OldResultTy, NewValue));
5589+ MachineRegisterInfo &RegInfo = DAG.getMachineFunction ().getRegInfo ();
5590+ Register DestReg = RegInfo.createVirtualRegister (
5591+ DAG.getTargetLoweringInfo ().getRegClassFor (MVT::i64 ));
5592+ SDValue RegCopy = DAG.getCopyToReg (Chain, DL, DestReg, NewValue);
5593+ SDValue Explode = DAG.getNode (ISD::CopyFromReg, DL,
5594+ {MVT::i32 , MVT::i32 , Chain.getValueType ()},
5595+ {RegCopy, DAG.getRegister (DestReg, MVT::i64 )});
5596+ // cast i64 result of new op back to <2 x float>
5597+ Results.push_back (DAG.getBitcast (
5598+ OldResultTy, DAG.getNode (ISD::BUILD_PAIR, DL, MVT::i64 ,
5599+ {Explode.getValue (0 ), Explode.getValue (1 )})));
55715600}
55725601
55735602void NVPTXTargetLowering::ReplaceNodeResults (
0 commit comments