Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,10 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
}
break;
}
case NVPTXISD::ATOMIC_CMP_SWAP_B128:
case NVPTXISD::ATOMIC_SWAP_B128:
selectAtomicSwap128(N);
return;
case ISD::FADD:
case ISD::FMUL:
case ISD::FSUB:
Expand Down Expand Up @@ -2337,3 +2341,28 @@ bool NVPTXDAGToDAGISel::tryIntrinsicVoid(SDNode *N) {
}
}
}

/// Select a 128-bit atomic exchange or compare-and-swap node
/// (NVPTXISD::ATOMIC_SWAP_B128 / NVPTXISD::ATOMIC_CMP_SWAP_B128) into the
/// matching ATOM_EXCH_B128 / ATOM_CAS_B128 machine instruction.
///
/// The machine node's operands are, in order: the address selected from
/// operand 1 (base, offset), every remaining value operand of \p N (operands
/// 2..end), the memory-order, scope and address-space immediates, and finally
/// the incoming chain (operand 0 of \p N).
void NVPTXDAGToDAGISel::selectAtomicSwap128(SDNode *N) {
  assert((N->getOpcode() == NVPTXISD::ATOMIC_CMP_SWAP_B128 ||
          N->getOpcode() == NVPTXISD::ATOMIC_SWAP_B128) &&
         "selectAtomicSwap128 called on a non-b128 atomic node");

  MemSDNode *AN = cast<MemSDNode>(N);
  SDLoc dl(N);

  // The result list (N->getVTList()) carries a chain, so the machine node must
  // also consume the incoming chain; otherwise the atomic is not ordered
  // against preceding memory operations.
  const SDValue Chain = N->getOperand(0);
  const auto [Base, Offset] = selectADDR(N->getOperand(1), CurDAG);

  SmallVector<SDValue, 10> Ops{Base, Offset};
  Ops.append(N->op_begin() + 2, N->op_end());
  Ops.append({
      getI32Imm(getMemOrder(AN), dl),
      getI32Imm(getAtomicScope(AN), dl),
      getI32Imm(getAddrSpace(AN), dl),
      Chain,
  });

  unsigned Opcode = N->getOpcode() == NVPTXISD::ATOMIC_SWAP_B128
                        ? NVPTX::ATOM_EXCH_B128
                        : NVPTX::ATOM_CAS_B128;

  auto *ATOM = CurDAG->getMachineNode(Opcode, dl, N->getVTList(), Ops);
  // Preserve the memory operand so later passes still see the access info.
  CurDAG->setNodeMemRefs(ATOM, AN->getMemOperand());

  ReplaceNode(N, ATOM);
}
1 change: 1 addition & 0 deletions llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
bool IsIm2Col = false);
void SelectTcgen05Ld(SDNode *N, bool hasOffset = false);
void SelectTcgen05St(SDNode *N, bool hasOffset = false);
void selectAtomicSwap128(SDNode *N);

inline SDValue getI32Imm(unsigned Imm, const SDLoc &DL) {
return CurDAG->getTargetConstant(Imm, DL, MVT::i32);
Expand Down
75 changes: 68 additions & 7 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1036,15 +1036,19 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setOperationAction(ISD::ADDRSPACECAST, {MVT::i32, MVT::i64}, Custom);

setOperationAction(ISD::ATOMIC_LOAD_SUB, {MVT::i32, MVT::i64}, Expand);
// No FPOW or FREM in PTX.

// atom.b128 is legal in PTX but since we don't represent i128 as a legal
// type, we need to custom lower it.
setOperationAction({ISD::ATOMIC_CMP_SWAP, ISD::ATOMIC_SWAP}, MVT::i128,
Custom);

// Now deduce the information based on the above mentioned
// actions
computeRegisterProperties(STI.getRegisterInfo());

// PTX support for 16-bit CAS is emulated. Only use 32+
setMinCmpXchgSizeInBits(STI.getMinCmpXchgSizeInBits());
setMaxAtomicSizeInBitsSupported(64);
setMaxAtomicSizeInBitsSupported(STI.hasAtomSwap128() ? 128 : 64);
setMaxDivRemBitWidthSupported(64);

// Custom lowering for tcgen05.ld vector operands
Expand Down Expand Up @@ -1077,6 +1081,8 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
case NVPTXISD::FIRST_NUMBER:
break;

MAKE_CASE(NVPTXISD::ATOMIC_CMP_SWAP_B128)
MAKE_CASE(NVPTXISD::ATOMIC_SWAP_B128)
MAKE_CASE(NVPTXISD::RET_GLUE)
MAKE_CASE(NVPTXISD::DeclareArrayParam)
MAKE_CASE(NVPTXISD::DeclareScalarParam)
Expand Down Expand Up @@ -6236,6 +6242,49 @@ static void replaceProxyReg(SDNode *N, SelectionDAG &DAG,
Results.push_back(Res);
}

/// Custom result replacement for i128 ISD::ATOMIC_SWAP and
/// ISD::ATOMIC_CMP_SWAP.
///
/// When the subtarget lacks b128 atomics, an "unsupported" diagnostic is
/// emitted and the results are replaced by an undef i128 plus the incoming
/// chain. Otherwise every i128 value operand is split into its low and high
/// i64 halves, an NVPTXISD::ATOMIC_SWAP_B128 / ATOMIC_CMP_SWAP_B128 memory
/// node is built from them, and its two i64 results are recombined into a
/// single i128 with BUILD_PAIR.
///
/// \param N        The generic atomic node being replaced (must produce i128).
/// \param DAG      The selection DAG used to create new nodes.
/// \param STI      Subtarget queried for b128 atomic support.
/// \param Results  Receives the replacement value followed by the chain.
static void replaceAtomicSwap128(SDNode *N, SelectionDAG &DAG,
                                 const NVPTXSubtarget &STI,
                                 SmallVectorImpl<SDValue> &Results) {
  assert(N->getValueType(0) == MVT::i128 &&
         "Custom lowering for atomic128 only supports i128");

  auto *Atomic = cast<AtomicSDNode>(N);
  SDLoc Loc(N);

  // Unsupported target: report it and hand back a placeholder value so that
  // legalization can continue.
  if (!STI.hasAtomSwap128()) {
    DAG.getContext()->diagnose(DiagnosticInfoUnsupported(
        DAG.getMachineFunction().getFunction(),
        "Support for b128 atomics introduced in PTX ISA version 8.3 and "
        "requires target sm_90.",
        Loc.getDebugLoc()));

    Results.push_back(DAG.getUNDEF(MVT::i128));
    Results.push_back(Atomic->getOperand(0)); // Chain
    return;
  }

  // Operand 0 is the chain and operand 1 the pointer; both pass through
  // unchanged. Everything after them is an i128 value that gets split.
  SmallVector<SDValue, 6> SplitOps;
  SplitOps.push_back(Atomic->getOperand(0));
  SplitOps.push_back(Atomic->getOperand(1));
  for (unsigned Idx = 2, End = Atomic->getNumOperands(); Idx != End; ++Idx) {
    const SDValue &Wide = Atomic->getOperand(Idx);
    // Element 0 is the low half, element 1 the high half.
    for (unsigned Half = 0; Half != 2; ++Half)
      SplitOps.push_back(DAG.getNode(ISD::EXTRACT_ELEMENT, Loc, MVT::i64, Wide,
                                     DAG.getIntPtrConstant(Half, Loc)));
  }

  unsigned Opcode = NVPTXISD::ATOMIC_CMP_SWAP_B128;
  if (N->getOpcode() == ISD::ATOMIC_SWAP)
    Opcode = NVPTXISD::ATOMIC_SWAP_B128;

  // Results of the target node: low i64, high i64, chain.
  SDVTList ResultTys = DAG.getVTList(MVT::i64, MVT::i64, MVT::Other);
  SDValue Split = DAG.getMemIntrinsicNode(Opcode, Loc, ResultTys, SplitOps,
                                          MVT::i128, Atomic->getMemOperand());

  SDValue Lo = Split.getValue(0);
  SDValue Hi = Split.getValue(1);
  Results.push_back(DAG.getNode(ISD::BUILD_PAIR, Loc, MVT::i128, {Lo, Hi}));
  Results.push_back(Split.getValue(2));
}

void NVPTXTargetLowering::ReplaceNodeResults(
SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
switch (N->getOpcode()) {
Expand All @@ -6256,6 +6305,10 @@ void NVPTXTargetLowering::ReplaceNodeResults(
case NVPTXISD::ProxyReg:
replaceProxyReg(N, DAG, *this, Results);
return;
case ISD::ATOMIC_CMP_SWAP:
case ISD::ATOMIC_SWAP:
replaceAtomicSwap128(N, DAG, STI, Results);
return;
}
}

Expand All @@ -6280,16 +6333,19 @@ NVPTXTargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const {
}

assert(Ty->isIntegerTy() && "Ty should be integer at this point");
auto ITy = cast<llvm::IntegerType>(Ty);
const unsigned BitWidth = cast<IntegerType>(Ty)->getBitWidth();

switch (AI->getOperation()) {
default:
return AtomicExpansionKind::CmpXChg;
case AtomicRMWInst::BinOp::Xchg:
if (BitWidth == 128)
return AtomicExpansionKind::None;
LLVM_FALLTHROUGH;
case AtomicRMWInst::BinOp::And:
case AtomicRMWInst::BinOp::Or:
case AtomicRMWInst::BinOp::Xor:
case AtomicRMWInst::BinOp::Xchg:
switch (ITy->getBitWidth()) {
switch (BitWidth) {
case 8:
case 16:
return AtomicExpansionKind::CmpXChg;
Expand All @@ -6299,6 +6355,8 @@ NVPTXTargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const {
if (STI.hasAtomBitwise64())
return AtomicExpansionKind::None;
return AtomicExpansionKind::CmpXChg;
case 128:
return AtomicExpansionKind::CmpXChg;
default:
llvm_unreachable("unsupported width encountered");
}
Expand All @@ -6308,7 +6366,7 @@ NVPTXTargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const {
case AtomicRMWInst::BinOp::Min:
case AtomicRMWInst::BinOp::UMax:
case AtomicRMWInst::BinOp::UMin:
switch (ITy->getBitWidth()) {
switch (BitWidth) {
case 8:
case 16:
return AtomicExpansionKind::CmpXChg;
Expand All @@ -6318,17 +6376,20 @@ NVPTXTargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const {
if (STI.hasAtomMinMax64())
return AtomicExpansionKind::None;
return AtomicExpansionKind::CmpXChg;
case 128:
return AtomicExpansionKind::CmpXChg;
default:
llvm_unreachable("unsupported width encountered");
}
case AtomicRMWInst::BinOp::UIncWrap:
case AtomicRMWInst::BinOp::UDecWrap:
switch (ITy->getBitWidth()) {
switch (BitWidth) {
case 32:
return AtomicExpansionKind::None;
case 8:
case 16:
case 64:
case 128:
return AtomicExpansionKind::CmpXChg;
default:
llvm_unreachable("unsupported width encountered");
Expand Down
12 changes: 11 additions & 1 deletion llvm/lib/Target/NVPTX/NVPTXISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,17 @@ enum NodeType : unsigned {
CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_Z,

FIRST_MEMORY_OPCODE,
LoadV2 = FIRST_MEMORY_OPCODE,

/// These nodes are used to lower atomic instructions with i128 type. They are
/// similar to the generic nodes, but the input and output values are split
/// into two 64-bit values.
/// ValLo, ValHi, OUTCHAIN = ATOMIC_CMP_SWAP_B128(INCHAIN, ptr, cmpLo, cmpHi,
/// swapLo, swapHi)
/// ValLo, ValHi, OUTCHAIN = ATOMIC_SWAP_B128(INCHAIN, ptr, amtLo, amtHi)
ATOMIC_CMP_SWAP_B128 = FIRST_MEMORY_OPCODE,
ATOMIC_SWAP_B128,

LoadV2,
LoadV4,
LoadV8,
LDUV2, // LDU.v2
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def hasAtomAddF64 : Predicate<"Subtarget->hasAtomAddF64()">;
def hasAtomScope : Predicate<"Subtarget->hasAtomScope()">;
def hasAtomBitwise64 : Predicate<"Subtarget->hasAtomBitwise64()">;
def hasAtomMinMax64 : Predicate<"Subtarget->hasAtomMinMax64()">;
def hasAtomSwap128 : Predicate<"Subtarget->hasAtomSwap128()">;
def hasClusters : Predicate<"Subtarget->hasClusters()">;
def hasPTXASUnreachableBug : Predicate<"Subtarget->hasPTXASUnreachableBug()">;
def noPTXASUnreachableBug : Predicate<"!Subtarget->hasPTXASUnreachableBug()">;
Expand Down
43 changes: 39 additions & 4 deletions llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
Original file line number Diff line number Diff line change
Expand Up @@ -1990,19 +1990,23 @@ multiclass F_ATOMIC_3<RegTyInfo t, string op_str, SDPatternOperator op, SDNode a

let mayLoad = 1, mayStore = 1, hasSideEffects = 1 in {
def _rr : BasicFlagsNVPTXInst<(outs t.RC:$dst),
(ins ADDR:$addr, t.RC:$b, t.RC:$c), (ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp),
(ins ADDR:$addr, t.RC:$b, t.RC:$c),
(ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp),
asm_str>;

def _ir : BasicFlagsNVPTXInst<(outs t.RC:$dst),
(ins ADDR:$addr, t.Imm:$b, t.RC:$c), (ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp),
(ins ADDR:$addr, t.Imm:$b, t.RC:$c),
(ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp),
asm_str>;

def _ri : BasicFlagsNVPTXInst<(outs t.RC:$dst),
(ins ADDR:$addr, t.RC:$b, t.Imm:$c), (ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp),
(ins ADDR:$addr, t.RC:$b, t.Imm:$c),
(ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp),
asm_str>;

def _ii : BasicFlagsNVPTXInst<(outs t.RC:$dst),
(ins ADDR:$addr, t.Imm:$b, t.Imm:$c), (ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp),
(ins ADDR:$addr, t.Imm:$b, t.Imm:$c),
(ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp),
asm_str>;
}

Expand Down Expand Up @@ -2200,6 +2204,37 @@ defm INT_PTX_SATOM_MIN : ATOM2_minmax_impl<"min">;
defm INT_PTX_SATOM_OR : ATOM2_bitwise_impl<"or">;
defm INT_PTX_SATOM_XOR : ATOM2_bitwise_impl<"xor">;

// atom.*.b128

// 128-bit atomic compare-and-swap and exchange.
//
// Each 128-bit value is carried as a pair of B64 operands (suffix 0 and 1).
// The emitted PTX declares temporary .b128 registers inside a braced block,
// packs each operand pair into one of them with mov.b128, performs the atom
// operation, and unpacks the 128-bit result into $dst0/$dst1.
// The $sem, $scope and $addsp operands select the printed atom qualifiers.
// Both instructions are gated on the hasAtomSwap128 predicate.
let mayLoad = true, mayStore = true, hasSideEffects = true,
Predicates = [hasAtomSwap128] in {
// atom.cas.b128: operands are the address, the compare value pair
// ($cmp0, $cmp1) and the swap value pair ($swap0, $swap1).
def ATOM_CAS_B128 :
NVPTXInst<
(outs B64:$dst0, B64:$dst1),
(ins ADDR:$addr, B64:$cmp0, B64:$cmp1, B64:$swap0, B64:$swap1,
AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp),
"{{\n\t"
".reg .b128 cmp, swap, dst;\n\t"
"mov.b128 cmp, {$cmp0, $cmp1};\n\t"
"mov.b128 swap, {$swap0, $swap1};\n\t"
"atom${sem:sem}${scope:scope}${addsp:addsp}.cas.b128 dst, [$addr], cmp, swap;\n\t"
"mov.b128 {$dst0, $dst1}, dst;\n\t"
"}}">;

// atom.exch.b128: operands are the address and the new value pair
// ($amt0, $amt1).
def ATOM_EXCH_B128 :
NVPTXInst<
(outs B64:$dst0, B64:$dst1),
(ins ADDR:$addr, B64:$amt0, B64:$amt1,
AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp),
"{{\n\t"
".reg .b128 amt, dst;\n\t"
"mov.b128 amt, {$amt0, $amt1};\n\t"
"atom${sem:sem}${scope:scope}${addsp:addsp}.exch.b128 dst, [$addr], amt;\n\t"
"mov.b128 {$dst0, $dst1}, dst;\n\t"
"}}">;
}


//-----------------------------------
// Support for ldu on sm_20 or later
//-----------------------------------
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/NVPTX/NVPTXSubtarget.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ class NVPTXSubtarget : public NVPTXGenSubtargetInfo {
bool hasAtomBitwise64() const { return SmVersion >= 32; }
bool hasAtomMinMax64() const { return SmVersion >= 32; }
bool hasAtomCas16() const { return SmVersion >= 70 && PTXVersion >= 63; }
// True when 128-bit atomic swap/compare-and-swap may be used: requires
// SmVersion >= 90 together with PTXVersion >= 83.
bool hasAtomSwap128() const { return SmVersion >= 90 && PTXVersion >= 83; }
bool hasClusters() const { return SmVersion >= 90 && PTXVersion >= 78; }
bool hasLDG() const { return SmVersion >= 32; }
bool hasHWROT32() const { return SmVersion >= 32; }
Expand Down
20 changes: 10 additions & 10 deletions llvm/test/CodeGen/NVPTX/atomicrmw-expand.err.ll
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,24 @@
; CHECK: error: unsupported cmpxchg
; CHECK: error: unsupported cmpxchg
; CHECK: error: unsupported cmpxchg
define void @bitwise_i128(ptr %0, i128 %1) {
define void @bitwise_i256(ptr %0, i256 %1) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should have two tests, one for i128 and one for i256, and the i256 should always fail, but the i128 one should fail or pass depending on whether sm < 90 or not.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is already coverage of i128 failing when unsupported in atomics-b128.ll. I don't think there is any reason to add that here as well.

entry:
%2 = atomicrmw and ptr %0, i128 %1 monotonic, align 16
%3 = atomicrmw or ptr %0, i128 %1 monotonic, align 16
%4 = atomicrmw xor ptr %0, i128 %1 monotonic, align 16
%5 = atomicrmw xchg ptr %0, i128 %1 monotonic, align 16
%2 = atomicrmw and ptr %0, i256 %1 monotonic, align 16
%3 = atomicrmw or ptr %0, i256 %1 monotonic, align 16
%4 = atomicrmw xor ptr %0, i256 %1 monotonic, align 16
%5 = atomicrmw xchg ptr %0, i256 %1 monotonic, align 16
ret void
}

; CHECK: error: unsupported cmpxchg
; CHECK: error: unsupported cmpxchg
; CHECK: error: unsupported cmpxchg
; CHECK: error: unsupported cmpxchg
define void @minmax_i128(ptr %0, i128 %1) {
define void @minmax_i256(ptr %0, i256 %1) {
entry:
%2 = atomicrmw min ptr %0, i128 %1 monotonic, align 16
%3 = atomicrmw max ptr %0, i128 %1 monotonic, align 16
%4 = atomicrmw umin ptr %0, i128 %1 monotonic, align 16
%5 = atomicrmw umax ptr %0, i128 %1 monotonic, align 16
%2 = atomicrmw min ptr %0, i256 %1 monotonic, align 16
%3 = atomicrmw max ptr %0, i256 %1 monotonic, align 16
%4 = atomicrmw umin ptr %0, i256 %1 monotonic, align 16
%5 = atomicrmw umax ptr %0, i256 %1 monotonic, align 16
ret void
}
Loading