Skip to content

Commit d254aed

Browse files
authored
[NVPTX] add support for 128-bit atomics (llvm#154852)
1 parent 8ec4db5 commit d254aed

File tree

9 files changed

+1163
-22
lines changed

9 files changed

+1163
-22
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,10 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
170170
}
171171
break;
172172
}
173+
case NVPTXISD::ATOMIC_CMP_SWAP_B128:
174+
case NVPTXISD::ATOMIC_SWAP_B128:
175+
selectAtomicSwap128(N);
176+
return;
173177
case ISD::FADD:
174178
case ISD::FMUL:
175179
case ISD::FSUB:
@@ -2337,3 +2341,28 @@ bool NVPTXDAGToDAGISel::tryIntrinsicVoid(SDNode *N) {
23372341
}
23382342
}
23392343
}
2344+
2345+
void NVPTXDAGToDAGISel::selectAtomicSwap128(SDNode *N) {
2346+
MemSDNode *AN = cast<MemSDNode>(N);
2347+
SDLoc dl(N);
2348+
2349+
const auto [Base, Offset] = selectADDR(N->getOperand(1), CurDAG);
2350+
SmallVector<SDValue, 5> Ops{Base, Offset};
2351+
Ops.append(N->op_begin() + 2, N->op_end());
2352+
Ops.append({
2353+
getI32Imm(getMemOrder(AN), dl),
2354+
getI32Imm(getAtomicScope(AN), dl),
2355+
getI32Imm(getAddrSpace(AN), dl),
2356+
});
2357+
2358+
assert(N->getOpcode() == NVPTXISD::ATOMIC_CMP_SWAP_B128 ||
2359+
N->getOpcode() == NVPTXISD::ATOMIC_SWAP_B128);
2360+
unsigned Opcode = N->getOpcode() == NVPTXISD::ATOMIC_SWAP_B128
2361+
? NVPTX::ATOM_EXCH_B128
2362+
: NVPTX::ATOM_CAS_B128;
2363+
2364+
auto *ATOM = CurDAG->getMachineNode(Opcode, dl, N->getVTList(), Ops);
2365+
CurDAG->setNodeMemRefs(ATOM, AN->getMemOperand());
2366+
2367+
ReplaceNode(N, ATOM);
2368+
}

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
9090
bool IsIm2Col = false);
9191
void SelectTcgen05Ld(SDNode *N, bool hasOffset = false);
9292
void SelectTcgen05St(SDNode *N, bool hasOffset = false);
93+
void selectAtomicSwap128(SDNode *N);
9394

9495
inline SDValue getI32Imm(unsigned Imm, const SDLoc &DL) {
9596
return CurDAG->getTargetConstant(Imm, DL, MVT::i32);

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 68 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1036,15 +1036,19 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
10361036
setOperationAction(ISD::ADDRSPACECAST, {MVT::i32, MVT::i64}, Custom);
10371037

10381038
setOperationAction(ISD::ATOMIC_LOAD_SUB, {MVT::i32, MVT::i64}, Expand);
1039-
// No FPOW or FREM in PTX.
1039+
1040+
// atom.b128 is legal in PTX but since we don't represent i128 as a legal
1041+
// type, we need to custom lower it.
1042+
setOperationAction({ISD::ATOMIC_CMP_SWAP, ISD::ATOMIC_SWAP}, MVT::i128,
1043+
Custom);
10401044

10411045
// Now deduce the information based on the above mentioned
10421046
// actions
10431047
computeRegisterProperties(STI.getRegisterInfo());
10441048

10451049
// PTX support for 16-bit CAS is emulated. Only use 32+
10461050
setMinCmpXchgSizeInBits(STI.getMinCmpXchgSizeInBits());
1047-
setMaxAtomicSizeInBitsSupported(64);
1051+
setMaxAtomicSizeInBitsSupported(STI.hasAtomSwap128() ? 128 : 64);
10481052
setMaxDivRemBitWidthSupported(64);
10491053

10501054
// Custom lowering for tcgen05.ld vector operands
@@ -1077,6 +1081,8 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
10771081
case NVPTXISD::FIRST_NUMBER:
10781082
break;
10791083

1084+
MAKE_CASE(NVPTXISD::ATOMIC_CMP_SWAP_B128)
1085+
MAKE_CASE(NVPTXISD::ATOMIC_SWAP_B128)
10801086
MAKE_CASE(NVPTXISD::RET_GLUE)
10811087
MAKE_CASE(NVPTXISD::DeclareArrayParam)
10821088
MAKE_CASE(NVPTXISD::DeclareScalarParam)
@@ -6236,6 +6242,49 @@ static void replaceProxyReg(SDNode *N, SelectionDAG &DAG,
62366242
Results.push_back(Res);
62376243
}
62386244

6245+
static void replaceAtomicSwap128(SDNode *N, SelectionDAG &DAG,
6246+
const NVPTXSubtarget &STI,
6247+
SmallVectorImpl<SDValue> &Results) {
6248+
assert(N->getValueType(0) == MVT::i128 &&
6249+
"Custom lowering for atomic128 only supports i128");
6250+
6251+
AtomicSDNode *AN = cast<AtomicSDNode>(N);
6252+
SDLoc dl(N);
6253+
6254+
if (!STI.hasAtomSwap128()) {
6255+
DAG.getContext()->diagnose(DiagnosticInfoUnsupported(
6256+
DAG.getMachineFunction().getFunction(),
6257+
"Support for b128 atomics introduced in PTX ISA version 8.3 and "
6258+
"requires target sm_90.",
6259+
dl.getDebugLoc()));
6260+
6261+
Results.push_back(DAG.getUNDEF(MVT::i128));
6262+
Results.push_back(AN->getOperand(0)); // Chain
6263+
return;
6264+
}
6265+
6266+
SmallVector<SDValue, 6> Ops;
6267+
Ops.push_back(AN->getOperand(0)); // Chain
6268+
Ops.push_back(AN->getOperand(1)); // Ptr
6269+
for (const auto &Op : AN->ops().drop_front(2)) {
6270+
// Low part
6271+
Ops.push_back(DAG.getNode(ISD::EXTRACT_ELEMENT, dl, MVT::i64, Op,
6272+
DAG.getIntPtrConstant(0, dl)));
6273+
// High part
6274+
Ops.push_back(DAG.getNode(ISD::EXTRACT_ELEMENT, dl, MVT::i64, Op,
6275+
DAG.getIntPtrConstant(1, dl)));
6276+
}
6277+
unsigned Opcode = N->getOpcode() == ISD::ATOMIC_SWAP
6278+
? NVPTXISD::ATOMIC_SWAP_B128
6279+
: NVPTXISD::ATOMIC_CMP_SWAP_B128;
6280+
SDVTList Tys = DAG.getVTList(MVT::i64, MVT::i64, MVT::Other);
6281+
SDValue Result = DAG.getMemIntrinsicNode(Opcode, dl, Tys, Ops, MVT::i128,
6282+
AN->getMemOperand());
6283+
Results.push_back(DAG.getNode(ISD::BUILD_PAIR, dl, MVT::i128,
6284+
{Result.getValue(0), Result.getValue(1)}));
6285+
Results.push_back(Result.getValue(2));
6286+
}
6287+
62396288
void NVPTXTargetLowering::ReplaceNodeResults(
62406289
SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
62416290
switch (N->getOpcode()) {
@@ -6256,6 +6305,10 @@ void NVPTXTargetLowering::ReplaceNodeResults(
62566305
case NVPTXISD::ProxyReg:
62576306
replaceProxyReg(N, DAG, *this, Results);
62586307
return;
6308+
case ISD::ATOMIC_CMP_SWAP:
6309+
case ISD::ATOMIC_SWAP:
6310+
replaceAtomicSwap128(N, DAG, STI, Results);
6311+
return;
62596312
}
62606313
}
62616314

@@ -6280,16 +6333,19 @@ NVPTXTargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const {
62806333
}
62816334

62826335
assert(Ty->isIntegerTy() && "Ty should be integer at this point");
6283-
auto ITy = cast<llvm::IntegerType>(Ty);
6336+
const unsigned BitWidth = cast<IntegerType>(Ty)->getBitWidth();
62846337

62856338
switch (AI->getOperation()) {
62866339
default:
62876340
return AtomicExpansionKind::CmpXChg;
6341+
case AtomicRMWInst::BinOp::Xchg:
6342+
if (BitWidth == 128)
6343+
return AtomicExpansionKind::None;
6344+
LLVM_FALLTHROUGH;
62886345
case AtomicRMWInst::BinOp::And:
62896346
case AtomicRMWInst::BinOp::Or:
62906347
case AtomicRMWInst::BinOp::Xor:
6291-
case AtomicRMWInst::BinOp::Xchg:
6292-
switch (ITy->getBitWidth()) {
6348+
switch (BitWidth) {
62936349
case 8:
62946350
case 16:
62956351
return AtomicExpansionKind::CmpXChg;
@@ -6299,6 +6355,8 @@ NVPTXTargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const {
62996355
if (STI.hasAtomBitwise64())
63006356
return AtomicExpansionKind::None;
63016357
return AtomicExpansionKind::CmpXChg;
6358+
case 128:
6359+
return AtomicExpansionKind::CmpXChg;
63026360
default:
63036361
llvm_unreachable("unsupported width encountered");
63046362
}
@@ -6308,7 +6366,7 @@ NVPTXTargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const {
63086366
case AtomicRMWInst::BinOp::Min:
63096367
case AtomicRMWInst::BinOp::UMax:
63106368
case AtomicRMWInst::BinOp::UMin:
6311-
switch (ITy->getBitWidth()) {
6369+
switch (BitWidth) {
63126370
case 8:
63136371
case 16:
63146372
return AtomicExpansionKind::CmpXChg;
@@ -6318,17 +6376,20 @@ NVPTXTargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const {
63186376
if (STI.hasAtomMinMax64())
63196377
return AtomicExpansionKind::None;
63206378
return AtomicExpansionKind::CmpXChg;
6379+
case 128:
6380+
return AtomicExpansionKind::CmpXChg;
63216381
default:
63226382
llvm_unreachable("unsupported width encountered");
63236383
}
63246384
case AtomicRMWInst::BinOp::UIncWrap:
63256385
case AtomicRMWInst::BinOp::UDecWrap:
6326-
switch (ITy->getBitWidth()) {
6386+
switch (BitWidth) {
63276387
case 32:
63286388
return AtomicExpansionKind::None;
63296389
case 8:
63306390
case 16:
63316391
case 64:
6392+
case 128:
63326393
return AtomicExpansionKind::CmpXChg;
63336394
default:
63346395
llvm_unreachable("unsupported width encountered");

llvm/lib/Target/NVPTX/NVPTXISelLowering.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,17 @@ enum NodeType : unsigned {
8181
CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_Z,
8282

8383
FIRST_MEMORY_OPCODE,
84-
LoadV2 = FIRST_MEMORY_OPCODE,
84+
85+
/// These nodes are used to lower atomic instructions with i128 type. They are
86+
/// similar to the generic nodes, but the input and output values are split
87+
/// into two 64-bit values.
88+
/// ValLo, ValHi, OUTCHAIN = ATOMIC_CMP_SWAP_B128(INCHAIN, ptr, cmpLo, cmpHi,
89+
/// swapLo, swapHi)
90+
/// ValLo, ValHi, OUTCHAIN = ATOMIC_SWAP_B128(INCHAIN, ptr, amtLo, amtHi)
91+
ATOMIC_CMP_SWAP_B128 = FIRST_MEMORY_OPCODE,
92+
ATOMIC_SWAP_B128,
93+
94+
LoadV2,
8595
LoadV4,
8696
LoadV8,
8797
LDUV2, // LDU.v2

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ def hasAtomAddF64 : Predicate<"Subtarget->hasAtomAddF64()">;
104104
def hasAtomScope : Predicate<"Subtarget->hasAtomScope()">;
105105
def hasAtomBitwise64 : Predicate<"Subtarget->hasAtomBitwise64()">;
106106
def hasAtomMinMax64 : Predicate<"Subtarget->hasAtomMinMax64()">;
107+
def hasAtomSwap128 : Predicate<"Subtarget->hasAtomSwap128()">;
107108
def hasClusters : Predicate<"Subtarget->hasClusters()">;
108109
def hasPTXASUnreachableBug : Predicate<"Subtarget->hasPTXASUnreachableBug()">;
109110
def noPTXASUnreachableBug : Predicate<"!Subtarget->hasPTXASUnreachableBug()">;

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1990,19 +1990,23 @@ multiclass F_ATOMIC_3<RegTyInfo t, string op_str, SDPatternOperator op, SDNode a
19901990

19911991
let mayLoad = 1, mayStore = 1, hasSideEffects = 1 in {
19921992
def _rr : BasicFlagsNVPTXInst<(outs t.RC:$dst),
1993-
(ins ADDR:$addr, t.RC:$b, t.RC:$c), (ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp),
1993+
(ins ADDR:$addr, t.RC:$b, t.RC:$c),
1994+
(ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp),
19941995
asm_str>;
19951996

19961997
def _ir : BasicFlagsNVPTXInst<(outs t.RC:$dst),
1997-
(ins ADDR:$addr, t.Imm:$b, t.RC:$c), (ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp),
1998+
(ins ADDR:$addr, t.Imm:$b, t.RC:$c),
1999+
(ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp),
19982000
asm_str>;
19992001

20002002
def _ri : BasicFlagsNVPTXInst<(outs t.RC:$dst),
2001-
(ins ADDR:$addr, t.RC:$b, t.Imm:$c), (ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp),
2003+
(ins ADDR:$addr, t.RC:$b, t.Imm:$c),
2004+
(ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp),
20022005
asm_str>;
20032006

20042007
def _ii : BasicFlagsNVPTXInst<(outs t.RC:$dst),
2005-
(ins ADDR:$addr, t.Imm:$b, t.Imm:$c), (ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp),
2008+
(ins ADDR:$addr, t.Imm:$b, t.Imm:$c),
2009+
(ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp),
20062010
asm_str>;
20072011
}
20082012

@@ -2200,6 +2204,37 @@ defm INT_PTX_SATOM_MIN : ATOM2_minmax_impl<"min">;
22002204
defm INT_PTX_SATOM_OR : ATOM2_bitwise_impl<"or">;
22012205
defm INT_PTX_SATOM_XOR : ATOM2_bitwise_impl<"xor">;
22022206

2207+
// atom.*.b128
2208+
2209+
let mayLoad = true, mayStore = true, hasSideEffects = true,
2210+
Predicates = [hasAtomSwap128] in {
2211+
def ATOM_CAS_B128 :
2212+
NVPTXInst<
2213+
(outs B64:$dst0, B64:$dst1),
2214+
(ins ADDR:$addr, B64:$cmp0, B64:$cmp1, B64:$swap0, B64:$swap1,
2215+
AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp),
2216+
"{{\n\t"
2217+
".reg .b128 cmp, swap, dst;\n\t"
2218+
"mov.b128 cmp, {$cmp0, $cmp1};\n\t"
2219+
"mov.b128 swap, {$swap0, $swap1};\n\t"
2220+
"atom${sem:sem}${scope:scope}${addsp:addsp}.cas.b128 dst, [$addr], cmp, swap;\n\t"
2221+
"mov.b128 {$dst0, $dst1}, dst;\n\t"
2222+
"}}">;
2223+
2224+
def ATOM_EXCH_B128 :
2225+
NVPTXInst<
2226+
(outs B64:$dst0, B64:$dst1),
2227+
(ins ADDR:$addr, B64:$amt0, B64:$amt1,
2228+
AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp),
2229+
"{{\n\t"
2230+
".reg .b128 amt, dst;\n\t"
2231+
"mov.b128 amt, {$amt0, $amt1};\n\t"
2232+
"atom${sem:sem}${scope:scope}${addsp:addsp}.exch.b128 dst, [$addr], amt;\n\t"
2233+
"mov.b128 {$dst0, $dst1}, dst;\n\t"
2234+
"}}">;
2235+
}
2236+
2237+
22032238
//-----------------------------------
22042239
// Support for ldu on sm_20 or later
22052240
//-----------------------------------

llvm/lib/Target/NVPTX/NVPTXSubtarget.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ class NVPTXSubtarget : public NVPTXGenSubtargetInfo {
8282
bool hasAtomBitwise64() const { return SmVersion >= 32; }
8383
bool hasAtomMinMax64() const { return SmVersion >= 32; }
8484
bool hasAtomCas16() const { return SmVersion >= 70 && PTXVersion >= 63; }
85+
bool hasAtomSwap128() const { return SmVersion >= 90 && PTXVersion >= 83; }
8586
bool hasClusters() const { return SmVersion >= 90 && PTXVersion >= 78; }
8687
bool hasLDG() const { return SmVersion >= 32; }
8788
bool hasHWROT32() const { return SmVersion >= 32; }

llvm/test/CodeGen/NVPTX/atomicrmw-expand.err.ll

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,24 +4,24 @@
44
; CHECK: error: unsupported cmpxchg
55
; CHECK: error: unsupported cmpxchg
66
; CHECK: error: unsupported cmpxchg
7-
define void @bitwise_i128(ptr %0, i128 %1) {
7+
define void @bitwise_i256(ptr %0, i256 %1) {
88
entry:
9-
%2 = atomicrmw and ptr %0, i128 %1 monotonic, align 16
10-
%3 = atomicrmw or ptr %0, i128 %1 monotonic, align 16
11-
%4 = atomicrmw xor ptr %0, i128 %1 monotonic, align 16
12-
%5 = atomicrmw xchg ptr %0, i128 %1 monotonic, align 16
9+
%2 = atomicrmw and ptr %0, i256 %1 monotonic, align 16
10+
%3 = atomicrmw or ptr %0, i256 %1 monotonic, align 16
11+
%4 = atomicrmw xor ptr %0, i256 %1 monotonic, align 16
12+
%5 = atomicrmw xchg ptr %0, i256 %1 monotonic, align 16
1313
ret void
1414
}
1515

1616
; CHECK: error: unsupported cmpxchg
1717
; CHECK: error: unsupported cmpxchg
1818
; CHECK: error: unsupported cmpxchg
1919
; CHECK: error: unsupported cmpxchg
20-
define void @minmax_i128(ptr %0, i128 %1) {
20+
define void @minmax_i256(ptr %0, i256 %1) {
2121
entry:
22-
%2 = atomicrmw min ptr %0, i128 %1 monotonic, align 16
23-
%3 = atomicrmw max ptr %0, i128 %1 monotonic, align 16
24-
%4 = atomicrmw umin ptr %0, i128 %1 monotonic, align 16
25-
%5 = atomicrmw umax ptr %0, i128 %1 monotonic, align 16
22+
%2 = atomicrmw min ptr %0, i256 %1 monotonic, align 16
23+
%3 = atomicrmw max ptr %0, i256 %1 monotonic, align 16
24+
%4 = atomicrmw umin ptr %0, i256 %1 monotonic, align 16
25+
%5 = atomicrmw umax ptr %0, i256 %1 monotonic, align 16
2626
ret void
2727
}

0 commit comments

Comments
 (0)