@@ -767,8 +767,10 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
767767 setOperationAction ({ISD::ABS, ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX},
768768 {MVT::i16 , MVT::i32 , MVT::i64 }, Legal);
769769
770+ setOperationAction ({ISD::CTPOP, ISD::CTLZ, ISD::CTLZ_ZERO_UNDEF}, MVT::i16 ,
771+ Promote);
770772 setOperationAction ({ISD::CTPOP, ISD::CTLZ}, MVT::i32 , Legal);
771- setOperationAction ({ISD::CTPOP, ISD::CTLZ}, { MVT::i16 , MVT:: i64 } , Custom);
773+ setOperationAction ({ISD::CTPOP, ISD::CTLZ}, MVT::i64 , Custom);
772774
773775 setI16x2OperationAction (ISD::ABS, MVT::v2i16, Legal, Custom);
774776 setI16x2OperationAction (ISD::SMIN, MVT::v2i16, Legal, Custom);
@@ -2743,40 +2745,17 @@ static SDValue LowerIntrinsicVoid(SDValue Op, SelectionDAG &DAG) {
27432745 return Op;
27442746}
27452747
2746- static SDValue lowerCTPOP (SDValue Op, SelectionDAG &DAG) {
2748+ // In PTX 64-bit CTLZ and CTPOP are supported, but they return a 32-bit value.
2749+ // Lower these into a node returning the correct type which is zero-extended
2750+ // back to the correct size.
2751+ static SDValue lowerCTLZCTPOP (SDValue Op, SelectionDAG &DAG) {
27472752 SDValue V = Op->getOperand (0 );
2748- SDLoc DL (Op);
2749-
2750- if (V.getValueType () == MVT::i16 ) {
2751- SDValue Zext = DAG.getNode (ISD::ZERO_EXTEND, DL, MVT::i32 , V);
2752- SDValue CT = DAG.getNode (ISD::CTPOP, DL, MVT::i32 , Zext);
2753- return DAG.getNode (ISD::TRUNCATE, DL, MVT::i16 , CT, SDNodeFlags::NoWrap);
2754- }
2755- if (V.getValueType () == MVT::i64 ) {
2756- SDValue CT = DAG.getNode (ISD::CTPOP, DL, MVT::i32 , V);
2757- return DAG.getNode (ISD::ZERO_EXTEND, DL, MVT::i64 , CT);
2758- }
2759- llvm_unreachable (" Unexpected CTPOP type to legalize" );
2760- }
2753+ assert (V.getValueType () == MVT::i64 &&
2754+ " Unexpected CTLZ/CTPOP type to legalize" );
27612755
2762- static SDValue lowerCTLZ (SDValue Op, SelectionDAG &DAG) {
2763- SDValue V = Op->getOperand (0 );
27642756 SDLoc DL (Op);
2765-
2766- if (V.getValueType () == MVT::i16 ) {
2767- SDValue Zext = DAG.getNode (ISD::ZERO_EXTEND, DL, MVT::i32 , V);
2768- SDValue CT = DAG.getNode (ISD::CTLZ, DL, MVT::i32 , Zext);
2769- SDValue Sub =
2770- DAG.getNode (ISD::ADD, DL, MVT::i32 , CT,
2771- DAG.getConstant (APInt (32 , -16 , true ), DL, MVT::i32 ),
2772- SDNodeFlags::NoSignedWrap);
2773- return DAG.getNode (ISD::TRUNCATE, DL, MVT::i16 , Sub, SDNodeFlags::NoWrap);
2774- }
2775- if (V.getValueType () == MVT::i64 ) {
2776- SDValue CT = DAG.getNode (ISD::CTLZ, DL, MVT::i32 , V);
2777- return DAG.getNode (ISD::ZERO_EXTEND, DL, MVT::i64 , CT);
2778- }
2779- llvm_unreachable (" Unexpected CTLZ type to legalize" );
2757+ SDValue CT = DAG.getNode (Op->getOpcode (), DL, MVT::i32 , V);
2758+ return DAG.getNode (ISD::ZERO_EXTEND, DL, MVT::i64 , CT, SDNodeFlags::NonNeg);
27802759}
27812760
27822761SDValue
@@ -2865,9 +2844,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
28652844 // Used only for bf16 on SM80, where we select fma for non-ftz operation
28662845 return PromoteBinOpIfF32FTZ (Op, DAG);
28672846 case ISD::CTPOP:
2868- return lowerCTPOP (Op, DAG);
28692847 case ISD::CTLZ:
2870- return lowerCTLZ (Op, DAG);
2848+ return lowerCTLZCTPOP (Op, DAG);
28712849
28722850 default :
28732851 llvm_unreachable (" Custom lowering not defined for operation" );
0 commit comments