@@ -1048,9 +1048,12 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
10481048 MVT::v32i32, MVT::v64i32, MVT::v128i32},
10491049 Custom);
10501050
1051- setOperationAction (ISD::INTRINSIC_WO_CHAIN, MVT::Other, Custom);
1052- // Enable custom lowering for the i128 bit operand with clusterlaunchcontrol
1053- setOperationAction (ISD::INTRINSIC_WO_CHAIN, MVT::i128 , Custom);
1051+ // Enable custom lowering for the following:
1052+ // * MVT::i128 - clusterlaunchcontrol
1053+ // * MVT::i32 - prmt
1054+ // * MVT::Other - internal.addrspace.wrap
1055+ setOperationAction (ISD::INTRINSIC_WO_CHAIN, {MVT::i32 , MVT::i128 , MVT::Other},
1056+ Custom);
10541057}
10551058
10561059const char *NVPTXTargetLowering::getTargetNodeName (unsigned Opcode) const {
@@ -2060,6 +2063,13 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
20602063 return DAG.getBuildVector (Node->getValueType (0 ), dl, Ops);
20612064}
20622065
2066+ static SDValue getPRMT (SDValue A, SDValue B, SDValue Selector, SDLoc DL,
2067+ SelectionDAG &DAG,
2068+ unsigned Mode = NVPTX::PTXPrmtMode::NONE) {
2069+ return DAG.getNode (NVPTXISD::PRMT, DL, MVT::i32 ,
2070+ {A, B, Selector, DAG.getConstant (Mode, DL, MVT::i32 )});
2071+ }
2072+
20632073SDValue NVPTXTargetLowering::LowerBITCAST (SDValue Op, SelectionDAG &DAG) const {
20642074 // Handle bitcasting from v2i8 without hitting the default promotion
20652075 // strategy which goes through stack memory.
@@ -2111,15 +2121,13 @@ SDValue NVPTXTargetLowering::LowerBUILD_VECTOR(SDValue Op,
21112121 L = DAG.getAnyExtOrTrunc (L, DL, MVT::i32 );
21122122 R = DAG.getAnyExtOrTrunc (R, DL, MVT::i32 );
21132123 }
2114- return DAG.getNode (
2115- NVPTXISD::PRMT, DL, MVT::v4i8,
2116- {L, R, DAG.getConstant (SelectionValue, DL, MVT::i32 ),
2117- DAG.getConstant (NVPTX::PTXPrmtMode::NONE, DL, MVT::i32 )});
2124+ return getPRMT (L, R, DAG.getConstant (SelectionValue, DL, MVT::i32 ), DL,
2125+ DAG);
21182126 };
21192127 auto PRMT__10 = GetPRMT (Op->getOperand (0 ), Op->getOperand (1 ), true , 0x3340 );
21202128 auto PRMT__32 = GetPRMT (Op->getOperand (2 ), Op->getOperand (3 ), true , 0x3340 );
21212129 auto PRMT3210 = GetPRMT (PRMT__10, PRMT__32, false , 0x5410 );
2122- return DAG.getNode (ISD::BITCAST, DL, VT, PRMT3210);
2130+ return DAG.getBitcast ( VT, PRMT3210);
21232131 }
21242132
21252133 // Get value or the Nth operand as an APInt(32). Undef values treated as 0.
@@ -2176,11 +2184,14 @@ SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
21762184 SDValue Selector = DAG.getNode (ISD::OR, DL, MVT::i32 ,
21772185 DAG.getZExtOrTrunc (Index, DL, MVT::i32 ),
21782186 DAG.getConstant (0x7770 , DL, MVT::i32 ));
2179- SDValue PRMT = DAG.getNode (
2180- NVPTXISD::PRMT, DL, MVT::i32 ,
2181- {DAG.getBitcast (MVT::i32 , Vector), DAG.getConstant (0 , DL, MVT::i32 ),
2182- Selector, DAG.getConstant (NVPTX::PTXPrmtMode::NONE, DL, MVT::i32 )});
2183- return DAG.getAnyExtOrTrunc (PRMT, DL, Op->getValueType (0 ));
2187+ SDValue PRMT = getPRMT (DAG.getBitcast (MVT::i32 , Vector),
2188+ DAG.getConstant (0 , DL, MVT::i32 ), Selector, DL, DAG);
2189+ SDValue Ext = DAG.getAnyExtOrTrunc (PRMT, DL, Op->getValueType (0 ));
2190+ SDNodeFlags Flags;
2191+ Flags.setNoSignedWrap (Ext.getScalarValueSizeInBits () > 8 );
2192+ Flags.setNoUnsignedWrap (Ext.getScalarValueSizeInBits () >= 8 );
2193+ Ext->setFlags (Flags);
2194+ return Ext;
21842195 }
21852196
21862197 // Constant index will be matched by tablegen.
@@ -2242,9 +2253,10 @@ SDValue NVPTXTargetLowering::LowerVECTOR_SHUFFLE(SDValue Op,
22422253 }
22432254
22442255 SDLoc DL (Op);
2245- return DAG.getNode (NVPTXISD::PRMT, DL, MVT::v4i8, V1, V2,
2246- DAG.getConstant (Selector, DL, MVT::i32 ),
2247- DAG.getConstant (NVPTX::PTXPrmtMode::NONE, DL, MVT::i32 ));
2256+ SDValue PRMT =
2257+ getPRMT (DAG.getBitcast (MVT::i32 , V1), DAG.getBitcast (MVT::i32 , V2),
2258+ DAG.getConstant (Selector, DL, MVT::i32 ), DL, DAG);
2259+ return DAG.getBitcast (Op.getValueType (), PRMT);
22482260}
22492261// / LowerShiftRightParts - Lower SRL_PARTS, SRA_PARTS, which
22502262// / 1) returns two i32 values and take a 2 x i32 value to shift plus a shift
@@ -2729,10 +2741,46 @@ static SDValue LowerClusterLaunchControlQueryCancel(SDValue Op,
27292741 {TryCancelResponse0, TryCancelResponse1});
27302742}
27312743
2744+ static SDValue lowerPrmtIntrinsic (SDValue Op, SelectionDAG &DAG) {
2745+ const unsigned Mode = [&]() {
2746+ switch (Op->getConstantOperandVal (0 )) {
2747+ case Intrinsic::nvvm_prmt:
2748+ return NVPTX::PTXPrmtMode::NONE;
2749+ case Intrinsic::nvvm_prmt_b4e:
2750+ return NVPTX::PTXPrmtMode::B4E;
2751+ case Intrinsic::nvvm_prmt_ecl:
2752+ return NVPTX::PTXPrmtMode::ECL;
2753+ case Intrinsic::nvvm_prmt_ecr:
2754+ return NVPTX::PTXPrmtMode::ECR;
2755+ case Intrinsic::nvvm_prmt_f4e:
2756+ return NVPTX::PTXPrmtMode::F4E;
2757+ case Intrinsic::nvvm_prmt_rc16:
2758+ return NVPTX::PTXPrmtMode::RC16;
2759+ case Intrinsic::nvvm_prmt_rc8:
2760+ return NVPTX::PTXPrmtMode::RC8;
2761+ default :
2762+ llvm_unreachable (" unsupported/unhandled intrinsic" );
2763+ }
2764+ }();
2765+ SDLoc DL (Op);
2766+ SDValue A = Op->getOperand (1 );
2767+ SDValue B = Op.getNumOperands () == 4 ? Op.getOperand (2 )
2768+ : DAG.getConstant (0 , DL, MVT::i32 );
2769+ SDValue Selector = (Op->op_end () - 1 )->get ();
2770+ return getPRMT (A, B, Selector, DL, DAG, Mode);
2771+ }
27322772static SDValue lowerIntrinsicWOChain (SDValue Op, SelectionDAG &DAG) {
27332773 switch (Op->getConstantOperandVal (0 )) {
27342774 default :
27352775 return Op;
2776+ case Intrinsic::nvvm_prmt:
2777+ case Intrinsic::nvvm_prmt_b4e:
2778+ case Intrinsic::nvvm_prmt_ecl:
2779+ case Intrinsic::nvvm_prmt_ecr:
2780+ case Intrinsic::nvvm_prmt_f4e:
2781+ case Intrinsic::nvvm_prmt_rc16:
2782+ case Intrinsic::nvvm_prmt_rc8:
2783+ return lowerPrmtIntrinsic (Op, DAG);
27362784 case Intrinsic::nvvm_internal_addrspace_wrap:
27372785 return Op.getOperand (1 );
27382786 case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_is_canceled:
@@ -5775,11 +5823,10 @@ PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
57755823 SDLoc DL (N);
57765824 auto &DAG = DCI.DAG ;
57775825
5778- auto PRMT = DAG.getNode (
5779- NVPTXISD::PRMT, DL, MVT::v4i8,
5780- {Op0, Op1, DAG.getConstant ((Op1Bytes << 8 ) | Op0Bytes, DL, MVT::i32 ),
5781- DAG.getConstant (NVPTX::PTXPrmtMode::NONE, DL, MVT::i32 )});
5782- return DAG.getNode (ISD::BITCAST, DL, VT, PRMT);
5826+ auto PRMT = getPRMT (
5827+ DAG.getBitcast (MVT::i32 , Op0), DAG.getBitcast (MVT::i32 , Op1),
5828+ DAG.getConstant ((Op1Bytes << 8 ) | Op0Bytes, DL, MVT::i32 ), DL, DAG);
5829+ return DAG.getBitcast (VT, PRMT);
57835830}
57845831
57855832static SDValue combineADDRSPACECAST (SDNode *N,
@@ -5797,47 +5844,116 @@ static SDValue combineADDRSPACECAST(SDNode *N,
57975844 return SDValue ();
57985845}
57995846
5847+ static APInt getPRMTSelector (APInt Selector, unsigned Mode) {
5848+ if (Mode == NVPTX::PTXPrmtMode::NONE)
5849+ return Selector;
5850+
5851+ unsigned V = Selector.trunc (2 ).getZExtValue ();
5852+
5853+ const auto GetSelector = [](unsigned S0, unsigned S1, unsigned S2,
5854+ unsigned S3) {
5855+ return APInt (32 , S0 | (S1 << 4 ) | (S2 << 8 ) | (S3 << 12 ));
5856+ };
5857+
5858+ switch (Mode) {
5859+ case NVPTX::PTXPrmtMode::F4E:
5860+ return GetSelector (V, V + 1 , V + 2 , V + 3 );
5861+ case NVPTX::PTXPrmtMode::B4E:
5862+ return GetSelector (V, (V - 1 ) & 7 , (V - 2 ) & 7 , (V - 3 ) & 7 );
5863+ case NVPTX::PTXPrmtMode::RC8:
5864+ return GetSelector (V, V, V, V);
5865+ case NVPTX::PTXPrmtMode::ECL:
5866+ return GetSelector (V, std::max (V, 1U ), std::max (V, 2U ), 3U );
5867+ case NVPTX::PTXPrmtMode::ECR:
5868+ return GetSelector (0 , std::min (V, 1U ), std::min (V, 2U ), V);
5869+ case NVPTX::PTXPrmtMode::RC16: {
5870+ unsigned V1 = (V & 1 ) << 1 ;
5871+ return GetSelector (V1, V1 + 1 , V1, V1 + 1 );
5872+ }
5873+ default :
5874+ llvm_unreachable (" Invalid PRMT mode" );
5875+ }
5876+ }
5877+
5878+ static APInt computePRMT (APInt A, APInt B, APInt Selector, unsigned Mode) {
5879+ // {b, a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}
5880+ APInt BitField = B.concat (A);
5881+ APInt SelectorVal = getPRMTSelector (Selector, Mode);
5882+ APInt Result (32 , 0 );
5883+ for (unsigned I : llvm::seq (4U )) {
5884+ APInt Sel = SelectorVal.extractBits (4 , I * 4 );
5885+ unsigned Idx = Sel.getLoBits (3 ).getZExtValue ();
5886+ unsigned Sign = Sel.getHiBits (1 ).getZExtValue ();
5887+ APInt Byte = BitField.extractBits (8 , Idx * 8 );
5888+ if (Sign)
5889+ Byte = Byte.ashr (8 );
5890+ Result.insertBits (Byte, I * 8 );
5891+ }
5892+ return Result;
5893+ }
5894+
5895+ static SDValue combinePRMT (SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
5896+ CodeGenOptLevel OptLevel) {
5897+ if (OptLevel == CodeGenOptLevel::None)
5898+ return SDValue ();
5899+
5900+ // Constant fold PRMT
5901+ if (isa<ConstantSDNode>(N->getOperand (0 )) &&
5902+ isa<ConstantSDNode>(N->getOperand (1 )) &&
5903+ isa<ConstantSDNode>(N->getOperand (2 )))
5904+ return DCI.DAG .getConstant (computePRMT (N->getConstantOperandAPInt (0 ),
5905+ N->getConstantOperandAPInt (1 ),
5906+ N->getConstantOperandAPInt (2 ),
5907+ N->getConstantOperandVal (3 )),
5908+ SDLoc (N), N->getValueType (0 ));
5909+
5910+ return SDValue ();
5911+ }
5912+
58005913SDValue NVPTXTargetLowering::PerformDAGCombine (SDNode *N,
58015914 DAGCombinerInfo &DCI) const {
58025915 CodeGenOptLevel OptLevel = getTargetMachine ().getOptLevel ();
58035916 switch (N->getOpcode ()) {
5804- default : break ;
5805- case ISD::ADD:
5806- return PerformADDCombine (N, DCI, OptLevel);
5807- case ISD::FADD:
5808- return PerformFADDCombine (N, DCI, OptLevel);
5809- case ISD::MUL:
5810- return PerformMULCombine (N, DCI, OptLevel);
5811- case ISD::SHL:
5812- return PerformSHLCombine (N, DCI, OptLevel);
5813- case ISD::AND:
5814- return PerformANDCombine (N, DCI);
5815- case ISD::UREM:
5816- case ISD::SREM:
5817- return PerformREMCombine (N, DCI, OptLevel);
5818- case ISD::SETCC:
5819- return PerformSETCCCombine (N, DCI, STI.getSmVersion ());
5820- case ISD::LOAD:
5821- case NVPTXISD::LoadParamV2:
5822- case NVPTXISD::LoadV2:
5823- case NVPTXISD::LoadV4:
5824- return combineUnpackingMovIntoLoad (N, DCI);
5825- case NVPTXISD::StoreParam:
5826- case NVPTXISD::StoreParamV2:
5827- case NVPTXISD::StoreParamV4:
5828- return PerformStoreParamCombine (N, DCI);
5829- case ISD::STORE:
5830- case NVPTXISD::StoreV2:
5831- case NVPTXISD::StoreV4:
5832- return PerformStoreCombine (N, DCI);
5833- case ISD::EXTRACT_VECTOR_ELT:
5834- return PerformEXTRACTCombine (N, DCI);
5835- case ISD::VSELECT:
5836- return PerformVSELECTCombine (N, DCI);
5837- case ISD::BUILD_VECTOR:
5838- return PerformBUILD_VECTORCombine (N, DCI);
5839- case ISD::ADDRSPACECAST:
5840- return combineADDRSPACECAST (N, DCI);
5917+ default :
5918+ break ;
5919+ case ISD::ADD:
5920+ return PerformADDCombine (N, DCI, OptLevel);
5921+ case ISD::FADD:
5922+ return PerformFADDCombine (N, DCI, OptLevel);
5923+ case ISD::MUL:
5924+ return PerformMULCombine (N, DCI, OptLevel);
5925+ case ISD::SHL:
5926+ return PerformSHLCombine (N, DCI, OptLevel);
5927+ case ISD::AND:
5928+ return PerformANDCombine (N, DCI);
5929+ case ISD::UREM:
5930+ case ISD::SREM:
5931+ return PerformREMCombine (N, DCI, OptLevel);
5932+ case ISD::SETCC:
5933+ return PerformSETCCCombine (N, DCI, STI.getSmVersion ());
5934+ case ISD::LOAD:
5935+ case NVPTXISD::LoadParamV2:
5936+ case NVPTXISD::LoadV2:
5937+ case NVPTXISD::LoadV4:
5938+ return combineUnpackingMovIntoLoad (N, DCI);
5939+ case NVPTXISD::StoreParam:
5940+ case NVPTXISD::StoreParamV2:
5941+ case NVPTXISD::StoreParamV4:
5942+ return PerformStoreParamCombine (N, DCI);
5943+ case ISD::STORE:
5944+ case NVPTXISD::StoreV2:
5945+ case NVPTXISD::StoreV4:
5946+ return PerformStoreCombine (N, DCI);
5947+ case ISD::EXTRACT_VECTOR_ELT:
5948+ return PerformEXTRACTCombine (N, DCI);
5949+ case ISD::VSELECT:
5950+ return PerformVSELECTCombine (N, DCI);
5951+ case ISD::BUILD_VECTOR:
5952+ return PerformBUILD_VECTORCombine (N, DCI);
5953+ case ISD::ADDRSPACECAST:
5954+ return combineADDRSPACECAST (N, DCI);
5955+ case NVPTXISD::PRMT:
5956+ return combinePRMT (N, DCI, OptLevel);
58415957 }
58425958 return SDValue ();
58435959}
@@ -6387,7 +6503,7 @@ static void computeKnownBitsForPRMT(const SDValue Op, KnownBits &Known,
63876503 ConstantSDNode *Selector = dyn_cast<ConstantSDNode>(Op.getOperand (2 ));
63886504 unsigned Mode = Op.getConstantOperandVal (3 );
63896505
6390- if (Mode != NVPTX::PTXPrmtMode::NONE || !Selector)
6506+ if (!Selector)
63916507 return ;
63926508
63936509 KnownBits AKnown = DAG.computeKnownBits (A, Depth);
@@ -6396,7 +6512,7 @@ static void computeKnownBitsForPRMT(const SDValue Op, KnownBits &Known,
63966512 // {b, a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}
63976513 KnownBits BitField = BKnown.concat (AKnown);
63986514
6399- APInt SelectorVal = Selector->getAPIntValue ();
6515+ APInt SelectorVal = getPRMTSelector ( Selector->getAPIntValue (), Mode );
64006516 for (unsigned I : llvm::seq (std::min (4U , Known.getBitWidth () / 8 ))) {
64016517 APInt Sel = SelectorVal.extractBits (4 , I * 4 );
64026518 unsigned Idx = Sel.getLoBits (3 ).getZExtValue ();
0 commit comments