@@ -1036,15 +1036,19 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
10361036 setOperationAction (ISD::ADDRSPACECAST, {MVT::i32 , MVT::i64 }, Custom);
10371037
10381038 setOperationAction (ISD::ATOMIC_LOAD_SUB, {MVT::i32 , MVT::i64 }, Expand);
1039- // No FPOW or FREM in PTX.
1039+
1040+ // atom.b128 is legal in PTX but since we don't represent i128 as a legal
1041+ // type, we need to custom lower it.
1042+ setOperationAction ({ISD::ATOMIC_CMP_SWAP, ISD::ATOMIC_SWAP}, MVT::i128 ,
1043+ Custom);
10401044
10411045 // Now deduce the information based on the above mentioned
10421046 // actions
10431047 computeRegisterProperties (STI.getRegisterInfo ());
10441048
10451049 // PTX support for 16-bit CAS is emulated. Only use 32+
10461050 setMinCmpXchgSizeInBits (STI.getMinCmpXchgSizeInBits ());
1047- setMaxAtomicSizeInBitsSupported (64 );
1051+ setMaxAtomicSizeInBitsSupported (STI. hasAtomSwap128 () ? 128 : 64 );
10481052 setMaxDivRemBitWidthSupported (64 );
10491053
10501054 // Custom lowering for tcgen05.ld vector operands
@@ -1077,6 +1081,8 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
10771081 case NVPTXISD::FIRST_NUMBER:
10781082 break ;
10791083
1084+ MAKE_CASE (NVPTXISD::ATOMIC_CMP_SWAP_B128)
1085+ MAKE_CASE (NVPTXISD::ATOMIC_SWAP_B128)
10801086 MAKE_CASE (NVPTXISD::RET_GLUE)
10811087 MAKE_CASE (NVPTXISD::DeclareArrayParam)
10821088 MAKE_CASE (NVPTXISD::DeclareScalarParam)
@@ -6236,6 +6242,49 @@ static void replaceProxyReg(SDNode *N, SelectionDAG &DAG,
62366242 Results.push_back (Res);
62376243}
62386244
6245+ static void replaceAtomicSwap128 (SDNode *N, SelectionDAG &DAG,
6246+ const NVPTXSubtarget &STI,
6247+ SmallVectorImpl<SDValue> &Results) {
6248+ assert (N->getValueType (0 ) == MVT::i128 &&
6249+ " Custom lowering for atomic128 only supports i128" );
6250+
6251+ AtomicSDNode *AN = cast<AtomicSDNode>(N);
6252+ SDLoc dl (N);
6253+
6254+ if (!STI.hasAtomSwap128 ()) {
6255+ DAG.getContext ()->diagnose (DiagnosticInfoUnsupported (
6256+ DAG.getMachineFunction ().getFunction (),
6257+ " Support for b128 atomics introduced in PTX ISA version 8.3 and "
6258+ " requires target sm_90." ,
6259+ dl.getDebugLoc ()));
6260+
6261+ Results.push_back (DAG.getUNDEF (MVT::i128 ));
6262+ Results.push_back (AN->getOperand (0 )); // Chain
6263+ return ;
6264+ }
6265+
6266+ SmallVector<SDValue, 6 > Ops;
6267+ Ops.push_back (AN->getOperand (0 )); // Chain
6268+ Ops.push_back (AN->getOperand (1 )); // Ptr
6269+ for (const auto &Op : AN->ops ().drop_front (2 )) {
6270+ // Low part
6271+ Ops.push_back (DAG.getNode (ISD::EXTRACT_ELEMENT, dl, MVT::i64 , Op,
6272+ DAG.getIntPtrConstant (0 , dl)));
6273+ // High part
6274+ Ops.push_back (DAG.getNode (ISD::EXTRACT_ELEMENT, dl, MVT::i64 , Op,
6275+ DAG.getIntPtrConstant (1 , dl)));
6276+ }
6277+ unsigned Opcode = N->getOpcode () == ISD::ATOMIC_SWAP
6278+ ? NVPTXISD::ATOMIC_SWAP_B128
6279+ : NVPTXISD::ATOMIC_CMP_SWAP_B128;
6280+ SDVTList Tys = DAG.getVTList (MVT::i64 , MVT::i64 , MVT::Other);
6281+ SDValue Result = DAG.getMemIntrinsicNode (Opcode, dl, Tys, Ops, MVT::i128 ,
6282+ AN->getMemOperand ());
6283+ Results.push_back (DAG.getNode (ISD::BUILD_PAIR, dl, MVT::i128 ,
6284+ {Result.getValue (0 ), Result.getValue (1 )}));
6285+ Results.push_back (Result.getValue (2 ));
6286+ }
6287+
62396288void NVPTXTargetLowering::ReplaceNodeResults (
62406289 SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
62416290 switch (N->getOpcode ()) {
@@ -6256,6 +6305,10 @@ void NVPTXTargetLowering::ReplaceNodeResults(
62566305 case NVPTXISD::ProxyReg:
62576306 replaceProxyReg (N, DAG, *this , Results);
62586307 return ;
6308+ case ISD::ATOMIC_CMP_SWAP:
6309+ case ISD::ATOMIC_SWAP:
6310+ replaceAtomicSwap128 (N, DAG, STI, Results);
6311+ return ;
62596312 }
62606313}
62616314
@@ -6280,16 +6333,19 @@ NVPTXTargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const {
62806333 }
62816334
62826335 assert (Ty->isIntegerTy () && " Ty should be integer at this point" );
6283- auto ITy = cast<llvm:: IntegerType>(Ty);
6336+ const unsigned BitWidth = cast<IntegerType>(Ty)-> getBitWidth ( );
62846337
62856338 switch (AI->getOperation ()) {
62866339 default :
62876340 return AtomicExpansionKind::CmpXChg;
6341+ case AtomicRMWInst::BinOp::Xchg:
6342+ if (BitWidth == 128 )
6343+ return AtomicExpansionKind::None;
6344+ LLVM_FALLTHROUGH;
62886345 case AtomicRMWInst::BinOp::And:
62896346 case AtomicRMWInst::BinOp::Or:
62906347 case AtomicRMWInst::BinOp::Xor:
6291- case AtomicRMWInst::BinOp::Xchg:
6292- switch (ITy->getBitWidth ()) {
6348+ switch (BitWidth) {
62936349 case 8 :
62946350 case 16 :
62956351 return AtomicExpansionKind::CmpXChg;
@@ -6299,6 +6355,8 @@ NVPTXTargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const {
62996355 if (STI.hasAtomBitwise64 ())
63006356 return AtomicExpansionKind::None;
63016357 return AtomicExpansionKind::CmpXChg;
6358+ case 128 :
6359+ return AtomicExpansionKind::CmpXChg;
63026360 default :
63036361 llvm_unreachable (" unsupported width encountered" );
63046362 }
@@ -6308,7 +6366,7 @@ NVPTXTargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const {
63086366 case AtomicRMWInst::BinOp::Min:
63096367 case AtomicRMWInst::BinOp::UMax:
63106368 case AtomicRMWInst::BinOp::UMin:
6311- switch (ITy-> getBitWidth () ) {
6369+ switch (BitWidth ) {
63126370 case 8 :
63136371 case 16 :
63146372 return AtomicExpansionKind::CmpXChg;
@@ -6318,17 +6376,20 @@ NVPTXTargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const {
63186376 if (STI.hasAtomMinMax64 ())
63196377 return AtomicExpansionKind::None;
63206378 return AtomicExpansionKind::CmpXChg;
6379+ case 128 :
6380+ return AtomicExpansionKind::CmpXChg;
63216381 default :
63226382 llvm_unreachable (" unsupported width encountered" );
63236383 }
63246384 case AtomicRMWInst::BinOp::UIncWrap:
63256385 case AtomicRMWInst::BinOp::UDecWrap:
6326- switch (ITy-> getBitWidth () ) {
6386+ switch (BitWidth ) {
63276387 case 32 :
63286388 return AtomicExpansionKind::None;
63296389 case 8 :
63306390 case 16 :
63316391 case 64 :
6392+ case 128 :
63326393 return AtomicExpansionKind::CmpXChg;
63336394 default :
63346395 llvm_unreachable (" unsupported width encountered" );
0 commit comments