@@ -551,6 +551,10 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
551551 setOperationAction (ISD::EXTRACT_VECTOR_ELT, MVT::v4i8, Custom);
552552 setOperationAction (ISD::INSERT_VECTOR_ELT, MVT::v4i8, Custom);
553553 setOperationAction (ISD::VECTOR_SHUFFLE, MVT::v4i8, Custom);
554+
555+ // Custom conversions to/from v2i8.
556+ setOperationAction (ISD::BITCAST, MVT::v2i8, Custom);
557+
554558 // Only logical ops can be done on v4i8 directly, others must be done
555559 // elementwise.
556560 setOperationAction (
@@ -2311,6 +2315,45 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
23112315 return DAG.getBuildVector (Node->getValueType (0 ), dl, Ops);
23122316}
23132317
2318+ SDValue NVPTXTargetLowering::LowerBITCAST (SDValue Op, SelectionDAG &DAG) const {
2319+ // Handle bitcasting to/from v2i8 without hitting the default promotion
2320+ // strategy which goes through stack memory.
2321+ SDNode *Node = Op.getNode ();
2322+ SDLoc dl (Node);
2323+
2324+ auto maybeBitcast = [&](EVT vt, SDValue val) {
2325+ if (val->getValueType (0 ) == vt) {
2326+ return val;
2327+ }
2328+ return DAG.getNode (ISD::BITCAST, dl, vt, val);
2329+ };
2330+
2331+ EVT VT = Op->getValueType (0 );
2332+ EVT fromVT = Op->getOperand (0 )->getValueType (0 );
2333+
2334+ if (VT == MVT::v2i8) {
2335+ SDValue reg = maybeBitcast (MVT::i16 , Op->getOperand (0 ));
2336+ // Promote result to v2i16
2337+ SDValue v0 = DAG.getNode (ISD::TRUNCATE, dl, MVT::i8 , reg);
2338+ SDValue C8 = DAG.getConstant (8 , dl, MVT::i16 );
2339+ SDValue v1 = DAG.getNode (ISD::TRUNCATE, dl, MVT::i8 ,
2340+ DAG.getNode (ISD::SRL, dl, MVT::i16 , {reg, C8}));
2341+ return DAG.getNode (ISD::BUILD_VECTOR, dl, MVT::v2i8, {v0, v1});
2342+ } else if (fromVT == MVT::v2i8) {
2343+ SDValue v0 = DAG.getNode (ISD::EXTRACT_VECTOR_ELT, dl, MVT::i8 , Op->getOperand (0 ),
2344+ DAG.getIntPtrConstant (0 , dl));
2345+ SDValue v1 = DAG.getNode (ISD::EXTRACT_VECTOR_ELT, dl, MVT::i8 , Op->getOperand (0 ),
2346+ DAG.getIntPtrConstant (1 , dl));
2347+ SDValue E0 = DAG.getNode (ISD::ZERO_EXTEND, dl, MVT::i16 , v0);
2348+ SDValue E1 = DAG.getNode (ISD::ZERO_EXTEND, dl, MVT::i16 , v1);
2349+ SDValue C8 = DAG.getConstant (8 , dl, MVT::i16 );
2350+ SDValue reg = DAG.getNode (ISD::OR, dl, MVT::i16 ,
2351+ {E0 , DAG.getNode (ISD::SHL, dl, MVT::i16 , {E1 , C8})});
2352+ return maybeBitcast (VT, reg);
2353+ }
2354+ return Op;
2355+ }
2356+
23142357// We can init constant f16x2/v2i16/v4i8 with a single .b32 move. Normally it
23152358// would get lowered as two constant loads and vector-packing move.
23162359// Instead we want just a constant move:
@@ -2818,6 +2861,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
28182861 return Op;
28192862 case ISD::BUILD_VECTOR:
28202863 return LowerBUILD_VECTOR (Op, DAG);
2864+ case ISD::BITCAST:
2865+ return LowerBITCAST (Op, DAG);
28212866 case ISD::EXTRACT_SUBVECTOR:
28222867 return Op;
28232868 case ISD::EXTRACT_VECTOR_ELT:
@@ -6413,6 +6458,9 @@ void NVPTXTargetLowering::ReplaceNodeResults(
64136458 switch (N->getOpcode ()) {
64146459 default :
64156460 report_fatal_error (" Unhandled custom legalization" );
6461+ case ISD::BITCAST:
6462+ Results.push_back (LowerBITCAST (SDValue (N, 0 ), DAG));
6463+ return ;
64166464 case ISD::LOAD:
64176465 ReplaceLoadVector (N, DAG, Results);
64186466 return ;
0 commit comments