@@ -1048,9 +1048,12 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
10481048 MVT::v32i32, MVT::v64i32, MVT::v128i32},
10491049 Custom);
10501050
1051- setOperationAction (ISD::INTRINSIC_WO_CHAIN, MVT::Other, Custom);
1052- // Enable custom lowering for the i128 bit operand with clusterlaunchcontrol
1053- setOperationAction (ISD::INTRINSIC_WO_CHAIN, MVT::i128 , Custom);
1051+ // Enable custom lowering for the following:
1052+ // * MVT::i128 - clusterlaunchcontrol
1053+ // * MVT::i32 - prmt
1054+ // * MVT::Other - internal.addrspace.wrap
1055+ setOperationAction (ISD::INTRINSIC_WO_CHAIN, {MVT::i32 , MVT::i128 , MVT::Other},
1056+ Custom);
10541057}
10551058
10561059const char *NVPTXTargetLowering::getTargetNodeName (unsigned Opcode) const {
@@ -2060,6 +2063,19 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
20602063 return DAG.getBuildVector (Node->getValueType (0 ), dl, Ops);
20612064}
20622065
2066+ static SDValue getPRMT (SDValue A, SDValue B, SDValue Selector, SDLoc DL,
2067+ SelectionDAG &DAG,
2068+ unsigned Mode = NVPTX::PTXPrmtMode::NONE) {
2069+ return DAG.getNode (NVPTXISD::PRMT, DL, MVT::i32 ,
2070+ {A, B, Selector, DAG.getConstant (Mode, DL, MVT::i32 )});
2071+ }
2072+
2073+ static SDValue getPRMT (SDValue A, SDValue B, uint64_t Selector, SDLoc DL,
2074+ SelectionDAG &DAG,
2075+ unsigned Mode = NVPTX::PTXPrmtMode::NONE) {
2076+ return getPRMT (A, B, DAG.getConstant (Selector, DL, MVT::i32 ), DL, DAG, Mode);
2077+ }
2078+
20632079SDValue NVPTXTargetLowering::LowerBITCAST (SDValue Op, SelectionDAG &DAG) const {
20642080 // Handle bitcasting from v2i8 without hitting the default promotion
20652081 // strategy which goes through stack memory.
@@ -2111,15 +2127,12 @@ SDValue NVPTXTargetLowering::LowerBUILD_VECTOR(SDValue Op,
21112127 L = DAG.getAnyExtOrTrunc (L, DL, MVT::i32 );
21122128 R = DAG.getAnyExtOrTrunc (R, DL, MVT::i32 );
21132129 }
2114- return DAG.getNode (
2115- NVPTXISD::PRMT, DL, MVT::v4i8,
2116- {L, R, DAG.getConstant (SelectionValue, DL, MVT::i32 ),
2117- DAG.getConstant (NVPTX::PTXPrmtMode::NONE, DL, MVT::i32 )});
2130+ return getPRMT (L, R, SelectionValue, DL, DAG);
21182131 };
21192132 auto PRMT__10 = GetPRMT (Op->getOperand (0 ), Op->getOperand (1 ), true , 0x3340 );
21202133 auto PRMT__32 = GetPRMT (Op->getOperand (2 ), Op->getOperand (3 ), true , 0x3340 );
21212134 auto PRMT3210 = GetPRMT (PRMT__10, PRMT__32, false , 0x5410 );
2122- return DAG.getNode (ISD::BITCAST, DL, VT, PRMT3210);
2135+ return DAG.getBitcast ( VT, PRMT3210);
21232136 }
21242137
21252138 // Get value or the Nth operand as an APInt(32). Undef values treated as 0.
@@ -2176,11 +2189,14 @@ SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
21762189 SDValue Selector = DAG.getNode (ISD::OR, DL, MVT::i32 ,
21772190 DAG.getZExtOrTrunc (Index, DL, MVT::i32 ),
21782191 DAG.getConstant (0x7770 , DL, MVT::i32 ));
2179- SDValue PRMT = DAG.getNode (
2180- NVPTXISD::PRMT, DL, MVT::i32 ,
2181- {DAG.getBitcast (MVT::i32 , Vector), DAG.getConstant (0 , DL, MVT::i32 ),
2182- Selector, DAG.getConstant (NVPTX::PTXPrmtMode::NONE, DL, MVT::i32 )});
2183- return DAG.getAnyExtOrTrunc (PRMT, DL, Op->getValueType (0 ));
2192+ SDValue PRMT = getPRMT (DAG.getBitcast (MVT::i32 , Vector),
2193+ DAG.getConstant (0 , DL, MVT::i32 ), Selector, DL, DAG);
2194+ SDValue Ext = DAG.getAnyExtOrTrunc (PRMT, DL, Op->getValueType (0 ));
2195+ SDNodeFlags Flags;
2196+ Flags.setNoSignedWrap (Ext.getScalarValueSizeInBits () > 8 );
2197+ Flags.setNoUnsignedWrap (Ext.getScalarValueSizeInBits () >= 8 );
2198+ Ext->setFlags (Flags);
2199+ return Ext;
21842200 }
21852201
21862202 // Constant index will be matched by tablegen.
@@ -2242,9 +2258,9 @@ SDValue NVPTXTargetLowering::LowerVECTOR_SHUFFLE(SDValue Op,
22422258 }
22432259
22442260 SDLoc DL (Op);
2245- return DAG. getNode (NVPTXISD:: PRMT, DL, MVT::v4i8 , V1, V2 ,
2246- DAG.getConstant ( Selector, DL, MVT:: i32 ),
2247- DAG.getConstant (NVPTX::PTXPrmtMode::NONE, DL, MVT:: i32 ) );
2261+ SDValue PRMT = getPRMT (DAG. getBitcast ( MVT::i32 , V1) ,
2262+ DAG.getBitcast (MVT:: i32 , V2), Selector, DL, DAG);
2263+ return DAG.getBitcast (Op. getValueType (), PRMT );
22482264}
22492265// / LowerShiftRightParts - Lower SRL_PARTS, SRA_PARTS, which
22502266// / 1) returns two i32 values and take a 2 x i32 value to shift plus a shift
@@ -2729,10 +2745,46 @@ static SDValue LowerClusterLaunchControlQueryCancel(SDValue Op,
27292745 {TryCancelResponse0, TryCancelResponse1});
27302746}
27312747
2748+ static SDValue lowerPrmtIntrinsic (SDValue Op, SelectionDAG &DAG) {
2749+ const unsigned Mode = [&]() {
2750+ switch (Op->getConstantOperandVal (0 )) {
2751+ case Intrinsic::nvvm_prmt:
2752+ return NVPTX::PTXPrmtMode::NONE;
2753+ case Intrinsic::nvvm_prmt_b4e:
2754+ return NVPTX::PTXPrmtMode::B4E;
2755+ case Intrinsic::nvvm_prmt_ecl:
2756+ return NVPTX::PTXPrmtMode::ECL;
2757+ case Intrinsic::nvvm_prmt_ecr:
2758+ return NVPTX::PTXPrmtMode::ECR;
2759+ case Intrinsic::nvvm_prmt_f4e:
2760+ return NVPTX::PTXPrmtMode::F4E;
2761+ case Intrinsic::nvvm_prmt_rc16:
2762+ return NVPTX::PTXPrmtMode::RC16;
2763+ case Intrinsic::nvvm_prmt_rc8:
2764+ return NVPTX::PTXPrmtMode::RC8;
2765+ default :
2766+ llvm_unreachable (" unsupported/unhandled intrinsic" );
2767+ }
2768+ }();
2769+ SDLoc DL (Op);
2770+ SDValue A = Op->getOperand (1 );
2771+ SDValue B = Op.getNumOperands () == 4 ? Op.getOperand (2 )
2772+ : DAG.getConstant (0 , DL, MVT::i32 );
2773+ SDValue Selector = (Op->op_end () - 1 )->get ();
2774+ return getPRMT (A, B, Selector, DL, DAG, Mode);
2775+ }
27322776static SDValue lowerIntrinsicWOChain (SDValue Op, SelectionDAG &DAG) {
27332777 switch (Op->getConstantOperandVal (0 )) {
27342778 default :
27352779 return Op;
2780+ case Intrinsic::nvvm_prmt:
2781+ case Intrinsic::nvvm_prmt_b4e:
2782+ case Intrinsic::nvvm_prmt_ecl:
2783+ case Intrinsic::nvvm_prmt_ecr:
2784+ case Intrinsic::nvvm_prmt_f4e:
2785+ case Intrinsic::nvvm_prmt_rc16:
2786+ case Intrinsic::nvvm_prmt_rc8:
2787+ return lowerPrmtIntrinsic (Op, DAG);
27362788 case Intrinsic::nvvm_internal_addrspace_wrap:
27372789 return Op.getOperand (1 );
27382790 case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_is_canceled:
@@ -5775,11 +5827,10 @@ PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
57755827 SDLoc DL (N);
57765828 auto &DAG = DCI.DAG ;
57775829
5778- auto PRMT = DAG.getNode (
5779- NVPTXISD::PRMT, DL, MVT::v4i8,
5780- {Op0, Op1, DAG.getConstant ((Op1Bytes << 8 ) | Op0Bytes, DL, MVT::i32 ),
5781- DAG.getConstant (NVPTX::PTXPrmtMode::NONE, DL, MVT::i32 )});
5782- return DAG.getNode (ISD::BITCAST, DL, VT, PRMT);
5830+ auto PRMT =
5831+ getPRMT (DAG.getBitcast (MVT::i32 , Op0), DAG.getBitcast (MVT::i32 , Op1),
5832+ (Op1Bytes << 8 ) | Op0Bytes, DL, DAG);
5833+ return DAG.getBitcast (VT, PRMT);
57835834}
57845835
57855836static SDValue combineADDRSPACECAST (SDNode *N,
@@ -5797,47 +5848,120 @@ static SDValue combineADDRSPACECAST(SDNode *N,
57975848 return SDValue ();
57985849}
57995850
5851+ // Given a constant selector value and a prmt mode, return the selector value
5852+ // normalized to the generic prmt mode. See the PTX ISA documentation for more
5853+ // details:
5854+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-prmt
5855+ static APInt getPRMTSelector (const APInt &Selector, unsigned Mode) {
5856+ if (Mode == NVPTX::PTXPrmtMode::NONE)
5857+ return Selector;
5858+
5859+ const unsigned V = Selector.trunc (2 ).getZExtValue ();
5860+
5861+ const auto GetSelector = [](unsigned S0, unsigned S1, unsigned S2,
5862+ unsigned S3) {
5863+ return APInt (32 , S0 | (S1 << 4 ) | (S2 << 8 ) | (S3 << 12 ));
5864+ };
5865+
5866+ switch (Mode) {
5867+ case NVPTX::PTXPrmtMode::F4E:
5868+ return GetSelector (V, V + 1 , V + 2 , V + 3 );
5869+ case NVPTX::PTXPrmtMode::B4E:
5870+ return GetSelector (V, (V - 1 ) & 7 , (V - 2 ) & 7 , (V - 3 ) & 7 );
5871+ case NVPTX::PTXPrmtMode::RC8:
5872+ return GetSelector (V, V, V, V);
5873+ case NVPTX::PTXPrmtMode::ECL:
5874+ return GetSelector (V, std::max (V, 1U ), std::max (V, 2U ), 3U );
5875+ case NVPTX::PTXPrmtMode::ECR:
5876+ return GetSelector (0 , std::min (V, 1U ), std::min (V, 2U ), V);
5877+ case NVPTX::PTXPrmtMode::RC16: {
5878+ unsigned V1 = (V & 1 ) << 1 ;
5879+ return GetSelector (V1, V1 + 1 , V1, V1 + 1 );
5880+ }
5881+ default :
5882+ llvm_unreachable (" Invalid PRMT mode" );
5883+ }
5884+ }
5885+
5886+ static APInt computePRMT (APInt A, APInt B, APInt Selector, unsigned Mode) {
5887+ // {b, a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}
5888+ APInt BitField = B.concat (A);
5889+ APInt SelectorVal = getPRMTSelector (Selector, Mode);
5890+ APInt Result (32 , 0 );
5891+ for (unsigned I : llvm::seq (4U )) {
5892+ APInt Sel = SelectorVal.extractBits (4 , I * 4 );
5893+ unsigned Idx = Sel.getLoBits (3 ).getZExtValue ();
5894+ unsigned Sign = Sel.getHiBits (1 ).getZExtValue ();
5895+ APInt Byte = BitField.extractBits (8 , Idx * 8 );
5896+ if (Sign)
5897+ Byte = Byte.ashr (8 );
5898+ Result.insertBits (Byte, I * 8 );
5899+ }
5900+ return Result;
5901+ }
5902+
5903+ static SDValue combinePRMT (SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
5904+ CodeGenOptLevel OptLevel) {
5905+ if (OptLevel == CodeGenOptLevel::None)
5906+ return SDValue ();
5907+
5908+ // Constant fold PRMT
5909+ if (isa<ConstantSDNode>(N->getOperand (0 )) &&
5910+ isa<ConstantSDNode>(N->getOperand (1 )) &&
5911+ isa<ConstantSDNode>(N->getOperand (2 )))
5912+ return DCI.DAG .getConstant (computePRMT (N->getConstantOperandAPInt (0 ),
5913+ N->getConstantOperandAPInt (1 ),
5914+ N->getConstantOperandAPInt (2 ),
5915+ N->getConstantOperandVal (3 )),
5916+ SDLoc (N), N->getValueType (0 ));
5917+
5918+ return SDValue ();
5919+ }
5920+
58005921SDValue NVPTXTargetLowering::PerformDAGCombine (SDNode *N,
58015922 DAGCombinerInfo &DCI) const {
58025923 CodeGenOptLevel OptLevel = getTargetMachine ().getOptLevel ();
58035924 switch (N->getOpcode ()) {
5804- default : break ;
5805- case ISD::ADD:
5806- return PerformADDCombine (N, DCI, OptLevel);
5807- case ISD::FADD:
5808- return PerformFADDCombine (N, DCI, OptLevel);
5809- case ISD::MUL:
5810- return PerformMULCombine (N, DCI, OptLevel);
5811- case ISD::SHL:
5812- return PerformSHLCombine (N, DCI, OptLevel);
5813- case ISD::AND:
5814- return PerformANDCombine (N, DCI);
5815- case ISD::UREM:
5816- case ISD::SREM:
5817- return PerformREMCombine (N, DCI, OptLevel);
5818- case ISD::SETCC:
5819- return PerformSETCCCombine (N, DCI, STI.getSmVersion ());
5820- case ISD::LOAD:
5821- case NVPTXISD::LoadParamV2:
5822- case NVPTXISD::LoadV2:
5823- case NVPTXISD::LoadV4:
5824- return combineUnpackingMovIntoLoad (N, DCI);
5825- case NVPTXISD::StoreParam:
5826- case NVPTXISD::StoreParamV2:
5827- case NVPTXISD::StoreParamV4:
5828- return PerformStoreParamCombine (N, DCI);
5829- case ISD::STORE:
5830- case NVPTXISD::StoreV2:
5831- case NVPTXISD::StoreV4:
5832- return PerformStoreCombine (N, DCI);
5833- case ISD::EXTRACT_VECTOR_ELT:
5834- return PerformEXTRACTCombine (N, DCI);
5835- case ISD::VSELECT:
5836- return PerformVSELECTCombine (N, DCI);
5837- case ISD::BUILD_VECTOR:
5838- return PerformBUILD_VECTORCombine (N, DCI);
5839- case ISD::ADDRSPACECAST:
5840- return combineADDRSPACECAST (N, DCI);
5925+ default :
5926+ break ;
5927+ case ISD::ADD:
5928+ return PerformADDCombine (N, DCI, OptLevel);
5929+ case ISD::ADDRSPACECAST:
5930+ return combineADDRSPACECAST (N, DCI);
5931+ case ISD::AND:
5932+ return PerformANDCombine (N, DCI);
5933+ case ISD::BUILD_VECTOR:
5934+ return PerformBUILD_VECTORCombine (N, DCI);
5935+ case ISD::EXTRACT_VECTOR_ELT:
5936+ return PerformEXTRACTCombine (N, DCI);
5937+ case ISD::FADD:
5938+ return PerformFADDCombine (N, DCI, OptLevel);
5939+ case ISD::LOAD:
5940+ case NVPTXISD::LoadParamV2:
5941+ case NVPTXISD::LoadV2:
5942+ case NVPTXISD::LoadV4:
5943+ return combineUnpackingMovIntoLoad (N, DCI);
5944+ case ISD::MUL:
5945+ return PerformMULCombine (N, DCI, OptLevel);
5946+ case NVPTXISD::PRMT:
5947+ return combinePRMT (N, DCI, OptLevel);
5948+ case ISD::SETCC:
5949+ return PerformSETCCCombine (N, DCI, STI.getSmVersion ());
5950+ case ISD::SHL:
5951+ return PerformSHLCombine (N, DCI, OptLevel);
5952+ case ISD::SREM:
5953+ case ISD::UREM:
5954+ return PerformREMCombine (N, DCI, OptLevel);
5955+ case NVPTXISD::StoreParam:
5956+ case NVPTXISD::StoreParamV2:
5957+ case NVPTXISD::StoreParamV4:
5958+ return PerformStoreParamCombine (N, DCI);
5959+ case ISD::STORE:
5960+ case NVPTXISD::StoreV2:
5961+ case NVPTXISD::StoreV4:
5962+ return PerformStoreCombine (N, DCI);
5963+ case ISD::VSELECT:
5964+ return PerformVSELECTCombine (N, DCI);
58415965 }
58425966 return SDValue ();
58435967}
@@ -6387,7 +6511,7 @@ static void computeKnownBitsForPRMT(const SDValue Op, KnownBits &Known,
63876511 ConstantSDNode *Selector = dyn_cast<ConstantSDNode>(Op.getOperand (2 ));
63886512 unsigned Mode = Op.getConstantOperandVal (3 );
63896513
6390- if (Mode != NVPTX::PTXPrmtMode::NONE || !Selector)
6514+ if (!Selector)
63916515 return ;
63926516
63936517 KnownBits AKnown = DAG.computeKnownBits (A, Depth);
@@ -6396,7 +6520,7 @@ static void computeKnownBitsForPRMT(const SDValue Op, KnownBits &Known,
63966520 // {b, a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}
63976521 KnownBits BitField = BKnown.concat (AKnown);
63986522
6399- APInt SelectorVal = Selector->getAPIntValue ();
6523+ APInt SelectorVal = getPRMTSelector ( Selector->getAPIntValue (), Mode );
64006524 for (unsigned I : llvm::seq (std::min (4U , Known.getBitWidth () / 8 ))) {
64016525 APInt Sel = SelectorVal.extractBits (4 , I * 4 );
64026526 unsigned Idx = Sel.getLoBits (3 ).getZExtValue ();
0 commit comments