@@ -38,6 +38,27 @@ def AS_match {
3838 }];
3939}
4040
// Defines four scope-restricted PatFrags for a three-operand (ptr, cmp, val)
// pattern operator: NAME#_cta, NAME#_cluster, NAME#_gpu and NAME#_sys.
// Each PatFrag wraps the operator record named NAME and matches only when the
// Scopes lookup of the node's sync-scope ID equals NVPTX::Scope::Block,
// ::Cluster, ::Device or ::System respectively.
// NOTE(review): the fragment body looks the operator up by NAME instead of
// using `frag`, so this presumably expects the defm name to be the same as the
// name of the record passed as `frag` -- confirm at the call sites.
41+ multiclass nvvm_ternary_atomic_op_scoped<SDPatternOperator frag> {
// frag_pat is only used below, via !setdagop(frag_pat, ops), to build the
// (ops node:$ptr, node:$cmp, node:$val) operand list of each PatFrag.
42+ defvar frag_pat = (frag node:$ptr, node:$cmp, node:$val);
// Matches when the scope is NVPTX::Scope::Block.
43+ def NAME#_cta: PatFrag<!setdagop(frag_pat, ops),
44+ (!cast<SDPatternOperator>(NAME) node:$ptr, node:$cmp, node:$val), [{
45+ return Scopes[cast<MemSDNode>(N)->getSyncScopeID()] == NVPTX::Scope::Block;
46+ }]>;
// Matches when the scope is NVPTX::Scope::Cluster.
47+ def NAME#_cluster : PatFrag<!setdagop(frag_pat, ops),
48+ (!cast<SDPatternOperator>(NAME) node:$ptr, node:$cmp, node:$val), [{
49+ return Scopes[cast<MemSDNode>(N)->getSyncScopeID()] == NVPTX::Scope::Cluster;
50+ }]>;
// Matches when the scope is NVPTX::Scope::Device.
51+ def NAME#_gpu: PatFrag<!setdagop(frag_pat, ops),
52+ (!cast<SDPatternOperator>(NAME) node:$ptr, node:$cmp, node:$val), [{
53+ return Scopes[cast<MemSDNode>(N)->getSyncScopeID()] == NVPTX::Scope::Device;
54+ }]>;
// Matches when the scope is NVPTX::Scope::System.
55+ def NAME#_sys: PatFrag<!setdagop(frag_pat, ops),
56+ (!cast<SDPatternOperator>(NAME) node:$ptr, node:$cmp, node:$val), [{
57+ return Scopes[cast<MemSDNode>(N)->getSyncScopeID()] == NVPTX::Scope::System;
58+ }]>;
59+ }
60+
61+
4162// A node that will be replaced with the current PTX version.
4263class PTX {
4364 SDNodeXForm PTXVerXform = SDNodeXForm<imm, [{
@@ -2022,40 +2043,41 @@ multiclass F_ATOMIC_2_NEG<ValueType regT, NVPTXRegClass regclass, string SpaceSt
20222043
20232044// has 3 operands
// Defines the four instruction variants of a three-operand atomic for one
// pointer type (ptrT / ptrclass): reg, imm1, imm2 and imm3. They differ only
// in whether operands $b and $c are registers (regclass) or immediates
// (IMMType). In the added ('+') lines the assembly string is "atom" followed
// by SemStr, ScopeStr, SpaceStr, OpcStr and TypeStr, in that order, and then
// the operand list; the removed ('-') lines are the same without ScopeStr.
// IntOp is the pattern operator matched by every variant, and Pred is handed
// to Requires<>.
20242045multiclass F_ATOMIC_3_imp<ValueType ptrT, NVPTXRegClass ptrclass,
2025- ValueType regT, NVPTXRegClass regclass, string SemStr,
2026- string SpaceStr, string TypeStr, string OpcStr, PatFrag IntOp,
2027- Operand IMMType, list<Predicate> Pred> {
2046+ ValueType regT, NVPTXRegClass regclass, string SemStr,
2047+ string ScopeStr, string SpaceStr, string TypeStr, string OpcStr,
2048+ PatFrag IntOp, Operand IMMType, list<Predicate> Pred> {
// All four variants are marked as loading, storing and having side effects.
20282049 let mayLoad = 1, mayStore = 1, hasSideEffects = 1 in {
// reg: both $b and $c are registers.
20292050 def reg : NVPTXInst<(outs regclass:$dst),
20302051 (ins ptrclass:$addr, regclass:$b, regclass:$c),
2031- !strconcat("atom", SemStr, SpaceStr, OpcStr, TypeStr, " \t$dst, [$addr], $b, $c;"),
2052+ !strconcat("atom", SemStr, ScopeStr, SpaceStr, OpcStr, TypeStr, " \t$dst, [$addr], $b, $c;"),
20322053 [(set (regT regclass:$dst), (IntOp (ptrT ptrclass:$addr), (regT regclass:$b), (regT regclass:$c)))]>,
20332054 Requires<Pred>;
20342055
// imm1: $b is an immediate, $c is a register.
20352056 def imm1 : NVPTXInst<(outs regclass:$dst),
20362057 (ins ptrclass:$addr, IMMType:$b, regclass:$c),
2037- !strconcat("atom", SemStr, SpaceStr, OpcStr, TypeStr, " \t$dst, [$addr], $b, $c;"),
2058+ !strconcat("atom", SemStr, ScopeStr, SpaceStr, OpcStr, TypeStr, " \t$dst, [$addr], $b, $c;"),
20382059 [(set (regT regclass:$dst), (IntOp (ptrT ptrclass:$addr), imm:$b, (regT regclass:$c)))]>,
20392060 Requires<Pred>;
20402061
// imm2: $b is a register, $c is an immediate. The trailing "" passed to
// !strconcat appends nothing to the assembly string.
20412062 def imm2 : NVPTXInst<(outs regclass:$dst),
20422063 (ins ptrclass:$addr, regclass:$b, IMMType:$c),
2043- !strconcat("atom", SemStr, SpaceStr, OpcStr, TypeStr, " \t$dst, [$addr], $b, $c;", ""),
2064+ !strconcat("atom", SemStr, ScopeStr, SpaceStr, OpcStr, TypeStr, " \t$dst, [$addr], $b, $c;", ""),
20442065 [(set (regT regclass:$dst), (IntOp (ptrT ptrclass:$addr), (regT regclass:$b), imm:$c))]>,
20452066 Requires<Pred>;
20462067
// imm3: both $b and $c are immediates.
20472068 def imm3 : NVPTXInst<(outs regclass:$dst),
20482069 (ins ptrclass:$addr, IMMType:$b, IMMType:$c),
2049- !strconcat("atom", SemStr, SpaceStr, OpcStr, TypeStr, " \t$dst, [$addr], $b, $c;"),
2070+ !strconcat("atom", SemStr, ScopeStr, SpaceStr, OpcStr, TypeStr, " \t$dst, [$addr], $b, $c;"),
20502071 [(set (regT regclass:$dst), (IntOp (ptrT ptrclass:$addr), imm:$b, imm:$c))]>,
20512072 Requires<Pred>;
20522073 }
20532074}
// Instantiates F_ATOMIC_3_imp twice, forwarding every argument unchanged:
//   p32 - pointer type i32 in Int32Regs
//   p64 - pointer type i64 in Int64Regs
// Pred defaults to the empty list. In the added ('+') lines ScopeStr is
// forwarded between SemStr and SpaceStr; the removed ('-') lines are the
// previous form without it.
2054- multiclass F_ATOMIC_3<ValueType regT, NVPTXRegClass regclass, string SemStr, string SpaceStr,
2055- string TypeStr, string OpcStr, PatFrag IntOp, Operand IMMType, list<Predicate> Pred = []> {
2056- defm p32 : F_ATOMIC_3_imp<i32, Int32Regs, regT, regclass, SemStr, SpaceStr, TypeStr,
2075+ multiclass F_ATOMIC_3<ValueType regT, NVPTXRegClass regclass, string SemStr, string ScopeStr,
2076+ string SpaceStr, string TypeStr, string OpcStr, PatFrag IntOp, Operand IMMType,
2077+ list<Predicate> Pred = []> {
2078+ defm p32 : F_ATOMIC_3_imp<i32, Int32Regs, regT, regclass, SemStr, ScopeStr, SpaceStr, TypeStr,
20572079 OpcStr, IntOp, IMMType, Pred>;
2058- defm p64 : F_ATOMIC_3_imp<i64, Int64Regs, regT, regclass, SemStr, SpaceStr, TypeStr,
2080+ defm p64 : F_ATOMIC_3_imp<i64, Int64Regs, regT, regclass, SemStr, ScopeStr, SpaceStr, TypeStr,
20592081 OpcStr, IntOp, IMMType, Pred>;
20602082}
20612083
@@ -2469,10 +2491,12 @@ foreach size = ["i16", "i32", "i64"] in {
24692491// ".cas", atomic_cmp_swap_i32_acquire_global, i32imm,
24702492// [hasSM<70>, hasPTX<63>]>
24712493multiclass INT_PTX_ATOM_CAS<string atomic_cmp_swap_pat, string type,
2472- string order, string addrspace, list<Predicate> preds>
2494+ string order, string scope, string addrspace,
2495+ list<Predicate> preds>
24732496 : F_ATOMIC_3<!cast<ValueType>("i"#type),
24742497 !cast<NVPTXRegClass>("Int"#type#"Regs"),
24752498 order,
2499+ scope,
24762500 addrspace,
24772501 ".b"#type,
24782502 ".cas",
@@ -2487,26 +2511,35 @@ foreach size = ["32", "64"] in {
24872511 defvar cas_addrspace_string = !if(!eq(addrspace, "generic"), "", "."#addrspace);
24882512 foreach order = ["acquire", "release", "acq_rel", "monotonic"] in {
24892513 defvar cas_order_string = !if(!eq(order, "monotonic"), ".relaxed", "."#order);
2514+ defvar atomic_cmp_swap_pat = !cast<PatFrag>("atomic_cmp_swap_i"#size#_#order#_#addrspace);
2515+ defm atomic_cmp_swap_i#size#_#order#_#addrspace: nvvm_ternary_atomic_op_scoped<atomic_cmp_swap_pat>;
2516+
2517+ foreach scope = ["cta", "cluster", "gpu", "sys"] in {
2518+ defm INT_PTX_ATOM_CAS_#size#_#order#addrspace#scope
2519+ : INT_PTX_ATOM_CAS<"atomic_cmp_swap_i"#size#_#order#_#addrspace#_#scope, size,
2520+ cas_order_string, "."#scope, cas_addrspace_string,
2521+ [hasSM<70>, hasPTX<63>]>;
2522+ }
24902523 // Note that AtomicExpand will convert cmpxchg seq_cst to a cmpxchg monotonic with fences around it.
24912524 // Memory orders are only supported for SM70+ and PTX63+, so we have two sets of instruction definitions:
24922525 // one for SM70+, and "old" ones, which lower to "atom.cas", for earlier archs.
24932526 defm INT_PTX_ATOM_CAS_#size#_#order#addrspace
24942527 : INT_PTX_ATOM_CAS<"atomic_cmp_swap_i"#size#_#order#_#addrspace, size,
2495- cas_order_string, cas_addrspace_string,
2528+ cas_order_string, "", cas_addrspace_string,
24962529 [hasSM<70>, hasPTX<63>]>;
24972530 defm INT_PTX_ATOM_CAS_#size#_#order#_old#addrspace
24982531 : INT_PTX_ATOM_CAS<"atomic_cmp_swap_i"#size#_#order#_#addrspace, size,
2499- "", cas_addrspace_string, []>;
2532+ "", "", cas_addrspace_string, []>;
25002533 }
25012534 }
25022535}
25032536
25042537// Note that 16-bit CAS support in PTX is emulated.
// 16-bit compare-and-swap instructions for the global, shared and generic
// address spaces. F_ATOMIC_3 takes (regT, regclass, SemStr, ScopeStr,
// SpaceStr, TypeStr, OpcStr, IntOp, IMMType, Pred); SemStr and ScopeStr are
// both "" here, so nothing is emitted for them. The qualifier strings are
// concatenated directly after "atom" in the assembly string, so they must not
// contain a leading space (".global", not " .global").
2505- defm INT_PTX_ATOM_CAS_G_16 : F_ATOMIC_3<i16, Int16Regs, "", ".global", ".b16", ".cas",
2538+ defm INT_PTX_ATOM_CAS_G_16 : F_ATOMIC_3<i16, Int16Regs, "", "", ".global", ".b16", ".cas",
25062539 atomic_cmp_swap_i16_global, i16imm, [hasSM<70>, hasPTX<63>]>;
2507- defm INT_PTX_ATOM_CAS_S_16 : F_ATOMIC_3<i16, Int16Regs, "", ".shared", ".b16", ".cas",
2540+ defm INT_PTX_ATOM_CAS_S_16 : F_ATOMIC_3<i16, Int16Regs, "", "", ".shared", ".b16", ".cas",
25082541 atomic_cmp_swap_i16_shared, i16imm, [hasSM<70>, hasPTX<63>]>;
2509- defm INT_PTX_ATOM_CAS_GEN_16 : F_ATOMIC_3<i16, Int16Regs, "", "", ".b16", ".cas",
2542+ defm INT_PTX_ATOM_CAS_GEN_16 : F_ATOMIC_3<i16, Int16Regs, "", "", "", ".b16", ".cas",
25102543 atomic_cmp_swap_i16_generic, i16imm, [hasSM<70>, hasPTX<63>]>;
25112544
25122545// Support for scoped atomic operations. Matches
0 commit comments