Skip to content

Commit 91cd2db

Browse files
durga4githubLukacma
authored andcommitted
[NVPTX] Add missing mbarrier intrinsics (llvm#164864)
This patch adds a few more mbarrier intrinsics, completing support for all the mbarrier variants up to Blackwell architecture. * Docs are updated in NVPTXUsage.rst. * lit tests are added for all the variants. * lit tests are verified with PTXAS from CUDA-12.8 toolkit. Signed-off-by: Durgadoss R <[email protected]>
1 parent c6522c4 commit 91cd2db

File tree

11 files changed

+1509
-1
lines changed

11 files changed

+1509
-1
lines changed

llvm/docs/NVPTXUsage.rst

Lines changed: 444 additions & 0 deletions
Large diffs are not rendered by default.

llvm/include/llvm/IR/IntrinsicsNVVM.td

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1866,6 +1866,73 @@ let IntrProperties = [IntrConvergent, IntrNoCallback] in {
18661866
def int_nvvm_mbarrier_pending_count : NVVMBuiltin,
18671867
Intrinsic<[llvm_i32_ty], [llvm_i64_ty], [IntrNoMem, IntrConvergent, IntrNoCallback]>;
18681868

1869+
// mbarrier.{expect_tx/complete_tx}
1870+
foreach op = ["expect_tx", "complete_tx"] in {
1871+
foreach scope = ["scope_cta", "scope_cluster"] in {
1872+
foreach space = ["space_cta", "space_cluster"] in {
1873+
defvar suffix = StrJoin<"_", [op, scope, space]>.ret;
1874+
defvar mbar_addr_ty = !if(!eq(space, "space_cta"),
1875+
llvm_shared_ptr_ty, llvm_shared_cluster_ptr_ty);
1876+
1877+
def int_nvvm_mbarrier_ # suffix :
1878+
Intrinsic<[], [mbar_addr_ty, llvm_i32_ty],
1879+
[IntrConvergent, IntrArgMemOnly, IntrNoCallback]>;
1880+
} // space
1881+
} // scope
1882+
} // op
1883+
1884+
// mbarrier.arrive and mbarrier.arrive.expect_tx
1885+
// mbarrier.arrive_drop and mbarrier.arrive_drop.expect_tx
1886+
foreach op = ["arrive", "arrive_expect_tx",
1887+
"arrive_drop", "arrive_drop_expect_tx"] in {
1888+
foreach scope = ["scope_cta", "scope_cluster"] in {
1889+
foreach space = ["space_cta", "space_cluster"] in {
1890+
defvar suffix = StrJoin<"_", [scope, space]>.ret;
1891+
defvar mbar_addr_ty = !if(!eq(space, "space_cta"),
1892+
llvm_shared_ptr_ty, llvm_shared_cluster_ptr_ty);
1893+
defvar args_ty = [mbar_addr_ty, // mbar_address_ptr
1894+
llvm_i32_ty]; // tx-count
1895+
1896+
// mbarriers in shared_cluster space cannot return any value.
1897+
defvar mbar_ret_ty = !if(!eq(space, "space_cta"),
1898+
[llvm_i64_ty], []<LLVMType>);
1899+
1900+
def int_nvvm_mbarrier_ # op # "_" # suffix:
1901+
Intrinsic<mbar_ret_ty, args_ty,
1902+
[IntrConvergent, IntrNoCallback]>;
1903+
def int_nvvm_mbarrier_ # op # "_relaxed_" # suffix :
1904+
Intrinsic<mbar_ret_ty, args_ty,
1905+
[IntrConvergent, IntrArgMemOnly, IntrNoCallback]>;
1906+
} // space
1907+
} // scope
1908+
} // op
1909+
1910+
// mbarrier.{test_wait and try_wait}
1911+
foreach op = ["test_wait", "try_wait"] in {
1912+
foreach scope = ["scope_cta", "scope_cluster"] in {
1913+
foreach parity = [true, false] in {
1914+
foreach time_limit = !if(!eq(op, "try_wait"), [true, false], [false]) in {
1915+
defvar base_args = [llvm_shared_ptr_ty]; // mbar_ptr
1916+
defvar parity_args = !if(parity, [llvm_i32_ty], [llvm_i64_ty]);
1917+
defvar tl_args = !if(time_limit, [llvm_i32_ty], []<LLVMType>);
1918+
defvar args = !listconcat(base_args, parity_args, tl_args);
1919+
defvar tmp_op = StrJoin<"_", [op,
1920+
!if(parity, "parity", ""),
1921+
!if(time_limit, "tl", "")]>.ret;
1922+
defvar suffix = StrJoin<"_", [scope, "space_cta"]>.ret;
1923+
1924+
def int_nvvm_mbarrier_ # tmp_op # "_" # suffix :
1925+
Intrinsic<[llvm_i1_ty], args,
1926+
[IntrConvergent, NoCapture<ArgIndex<0>>, IntrNoCallback]>;
1927+
def int_nvvm_mbarrier_ # tmp_op # "_relaxed_" # suffix :
1928+
Intrinsic<[llvm_i1_ty], args,
1929+
[IntrConvergent, NoCapture<ArgIndex<0>>, IntrNoCallback,
1930+
IntrArgMemOnly, IntrReadMem]>;
1931+
} // tl
1932+
} // parity
1933+
} // scope
1934+
} // op
1935+
18691936
// Generated within nvvm. Use for ldu on sm_20 or later. Second arg is the
18701937
// pointer's alignment.
18711938
let IntrProperties = [IntrReadMem, IntrArgMemOnly, IntrNoCallback, IntrWillReturn, NoCapture<ArgIndex<0>>] in {
@@ -3000,4 +3067,4 @@ foreach sp = [0, 1] in {
30003067
}
30013068
}
30023069

3003-
} // let TargetPrefix = "nvvm"
3070+
} // let TargetPrefix = "nvvm"

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1082,6 +1082,161 @@ let Predicates = [hasPTX<70>, hasSM<80>] in {
10821082
"mbarrier.pending_count.b64",
10831083
[(set i32:$res, (int_nvvm_mbarrier_pending_count i64:$state))]>;
10841084
}
1085+
1086+
class MBAR_UTIL<string op, string scope,
1087+
string space = "", string sem = "",
1088+
bit tl = 0, bit parity = 0> {
1089+
// The mbarrier instructions in PTX ISA are of the general form:
1090+
// mbarrier.op.semantics.scope.space.b64 arg1, arg2 ...
1091+
// where:
1092+
// op -> arrive, expect_tx, complete_tx, arrive.expect_tx etc.
1093+
// semantics -> acquire, release, relaxed (default depends on the op)
1094+
// scope -> cta or cluster (default is cta-scope)
1095+
// space -> shared::cta or shared::cluster (default is shared::cta)
1096+
//
1097+
// The 'semantics' and 'scope' go together. If one is specified,
1098+
// then the other _must_ be specified. For example:
1099+
// (A) mbarrier.arrive <args> (valid, release and cta are default)
1100+
// (B) mbarrier.arrive.release.cta <args> (valid, sem/scope mentioned explicitly)
1101+
// (C) mbarrier.arrive.release <args> (invalid, needs scope)
1102+
// (D) mbarrier.arrive.cta <args> (invalid, needs order)
1103+
//
1104+
// Wherever possible, we prefer form (A) to (B) since it is available
1105+
// from early PTX versions. In most cases, explicitly specifying the
1106+
// scope requires a later version of PTX.
1107+
string _scope_asm = !cond(
1108+
!eq(scope, "scope_cluster") : "cluster",
1109+
!eq(scope, "scope_cta") : !if(!empty(sem), "", "cta"),
1110+
true : scope);
1111+
string _space_asm = !cond(
1112+
!eq(space, "space_cta") : "shared",
1113+
!eq(space, "space_cluster") : "shared::cluster",
1114+
true : space);
1115+
1116+
string _parity = !if(parity, "parity", "");
1117+
string asm_str = StrJoin<".", ["mbarrier", op, _parity,
1118+
sem, _scope_asm, _space_asm, "b64"]>.ret;
1119+
1120+
string _intr_suffix = StrJoin<"_", [!subst(".", "_", op), _parity,
1121+
!if(tl, "tl", ""),
1122+
sem, scope, space]>.ret;
1123+
string intr_name = "int_nvvm_mbarrier_" # _intr_suffix;
1124+
1125+
// Predicate checks:
1126+
// These are used only for the "test_wait/try_wait" variants as they
1127+
// have evolved since sm80 and are complex. The predicates for the
1128+
// remaining instructions are straightforward and have already been
1129+
// applied directly.
1130+
Predicate _sm_pred = !cond(!or(
1131+
!eq(op, "try_wait"),
1132+
!eq(scope, "scope_cluster"),
1133+
!eq(sem, "relaxed")) : hasSM<90>,
1134+
true : hasSM<80>);
1135+
Predicate _ptx_pred = !cond(
1136+
!eq(sem, "relaxed") : hasPTX<86>,
1137+
!ne(_scope_asm, "") : hasPTX<80>,
1138+
!eq(op, "try_wait") : hasPTX<78>,
1139+
parity : hasPTX<71>,
1140+
true : hasPTX<70>);
1141+
list<Predicate> preds = [_ptx_pred, _sm_pred];
1142+
}
1143+
1144+
foreach op = ["expect_tx", "complete_tx"] in {
1145+
foreach scope = ["scope_cta", "scope_cluster"] in {
1146+
foreach space = ["space_cta", "space_cluster"] in {
1147+
defvar intr = !cast<Intrinsic>(MBAR_UTIL<op, scope, space>.intr_name);
1148+
defvar suffix = StrJoin<"_", [op, scope, space]>.ret;
1149+
def mbar_ # suffix : BasicNVPTXInst<(outs), (ins ADDR:$addr, B32:$tx_count),
1150+
MBAR_UTIL<op, scope, space, "relaxed">.asm_str,
1151+
[(intr addr:$addr, i32:$tx_count)]>,
1152+
Requires<[hasPTX<80>, hasSM<90>]>;
1153+
} // space
1154+
} // scope
1155+
} // op
1156+
1157+
multiclass MBAR_ARR_INTR<string op, string scope, string sem,
1158+
list<Predicate> pred = []> {
1159+
// When either of sem or scope is non-default, both have to
1160+
// be explicitly specified. So, explicitly state that
1161+
// sem is `release` when scope is `cluster`.
1162+
defvar asm_sem = !if(!and(!empty(sem), !eq(scope, "scope_cluster")),
1163+
"release", sem);
1164+
1165+
defvar asm_cta = MBAR_UTIL<op, scope, "space_cta", asm_sem>.asm_str;
1166+
defvar intr_cta = !cast<Intrinsic>(MBAR_UTIL<op, scope,
1167+
"space_cta", sem>.intr_name);
1168+
1169+
defvar asm_cluster = MBAR_UTIL<op, scope, "space_cluster", asm_sem>.asm_str;
1170+
defvar intr_cluster = !cast<Intrinsic>(MBAR_UTIL<op, scope,
1171+
"space_cluster", sem>.intr_name);
1172+
1173+
def _CTA : NVPTXInst<(outs B64:$state),
1174+
(ins ADDR:$addr, B32:$tx_count),
1175+
asm_cta # " $state, [$addr], $tx_count;",
1176+
[(set i64:$state, (intr_cta addr:$addr, i32:$tx_count))]>,
1177+
Requires<pred>;
1178+
def _CLUSTER : NVPTXInst<(outs),
1179+
(ins ADDR:$addr, B32:$tx_count),
1180+
asm_cluster # " _, [$addr], $tx_count;",
1181+
[(intr_cluster addr:$addr, i32:$tx_count)]>,
1182+
Requires<pred>;
1183+
}
1184+
foreach op = ["arrive", "arrive.expect_tx",
1185+
"arrive_drop", "arrive_drop.expect_tx"] in {
1186+
foreach scope = ["scope_cta", "scope_cluster"] in {
1187+
defvar suffix = !subst(".", "_", op) # scope;
1188+
defm mbar_ # suffix # _release : MBAR_ARR_INTR<op, scope, "", [hasPTX<80>, hasSM<90>]>;
1189+
defm mbar_ # suffix # _relaxed : MBAR_ARR_INTR<op, scope, "relaxed", [hasPTX<86>, hasSM<90>]>;
1190+
} // scope
1191+
} // op
1192+
1193+
multiclass MBAR_WAIT_INTR<string op, string scope, string sem, bit time_limit> {
1194+
// When either of sem or scope is non-default, both have to
1195+
// be explicitly specified. So, explicitly state that the
1196+
// semantics is `acquire` when the scope is `cluster`.
1197+
defvar asm_sem = !if(!and(!empty(sem), !eq(scope, "scope_cluster")),
1198+
"acquire", sem);
1199+
1200+
defvar asm_parity = MBAR_UTIL<op, scope, "space_cta", asm_sem,
1201+
time_limit, 1>.asm_str;
1202+
defvar pred_parity = MBAR_UTIL<op, scope, "space_cta", asm_sem,
1203+
time_limit, 1>.preds;
1204+
defvar intr_parity = !cast<Intrinsic>(MBAR_UTIL<op, scope, "space_cta",
1205+
sem, time_limit, 1>.intr_name);
1206+
1207+
defvar asm_state = MBAR_UTIL<op, scope, "space_cta", asm_sem,
1208+
time_limit>.asm_str;
1209+
defvar pred_state = MBAR_UTIL<op, scope, "space_cta", asm_sem,
1210+
time_limit>.preds;
1211+
defvar intr_state = !cast<Intrinsic>(MBAR_UTIL<op, scope, "space_cta",
1212+
sem, time_limit>.intr_name);
1213+
1214+
defvar ins_tl_dag = !if(time_limit, (ins B32:$tl), (ins));
1215+
defvar tl_suffix = !if(time_limit, ", $tl;", ";");
1216+
defvar intr_state_dag = !con((intr_state addr:$addr, i64:$state),
1217+
!if(time_limit, (intr_state i32:$tl), (intr_state)));
1218+
defvar intr_parity_dag = !con((intr_parity addr:$addr, i32:$phase),
1219+
!if(time_limit, (intr_parity i32:$tl), (intr_parity)));
1220+
1221+
def _STATE : NVPTXInst<(outs B1:$res), !con((ins ADDR:$addr, B64:$state), ins_tl_dag),
1222+
asm_state # " $res, [$addr], $state" # tl_suffix,
1223+
[(set i1:$res, intr_state_dag)]>,
1224+
Requires<pred_state>;
1225+
def _PARITY : NVPTXInst<(outs B1:$res), !con((ins ADDR:$addr, B32:$phase), ins_tl_dag),
1226+
asm_parity # " $res, [$addr], $phase" # tl_suffix,
1227+
[(set i1:$res, intr_parity_dag)]>,
1228+
Requires<pred_parity>;
1229+
}
1230+
foreach op = ["test_wait", "try_wait"] in {
1231+
foreach scope = ["scope_cta", "scope_cluster"] in {
1232+
foreach time_limit = !if(!eq(op, "try_wait"), [true, false], [false]) in {
1233+
defvar suffix = StrJoin<"_", [op, scope, !if(time_limit, "tl", "")]>.ret;
1234+
defm mbar_ # suffix # "_acquire" : MBAR_WAIT_INTR<op, scope, "", time_limit>;
1235+
defm mbar_ # suffix # "_relaxed" : MBAR_WAIT_INTR<op, scope, "relaxed", time_limit>;
1236+
} // time_limit
1237+
} // scope
1238+
} // op
1239+
10851240
//-----------------------------------
10861241
// Math Functions
10871242
//-----------------------------------

0 commit comments

Comments
 (0)