@@ -1082,6 +1082,161 @@ let Predicates = [hasPTX<70>, hasSM<80>] in {
10821082            "mbarrier.pending_count.b64",
10831083            [(set i32:$res, (int_nvvm_mbarrier_pending_count i64:$state))]>;
10841084}
1085+ 
// MBAR_UTIL derives, for one mbarrier variant, the strings and predicates
// shared by the instruction definitions:
//   asm_str   - the PTX mnemonic
//   intr_name - the name of the corresponding NVVM intrinsic record
//   preds     - the PTX-version and SM-version predicates
//
// PTX spells the mbarrier instructions as
//   mbarrier.<op>.<semantics>.<scope>.<space>.b64 arg1, arg2 ...
// with:
//   op        : arrive, expect_tx, complete_tx, arrive.expect_tx etc.
//   semantics : acquire, release, relaxed (the default depends on the op)
//   scope     : cta or cluster (cta when omitted)
//   space     : shared::cta or shared::cluster (shared::cta when omitted)
//
// 'semantics' and 'scope' travel as a pair: if one of them is written,
// the other _must_ be written as well. For example:
//   (A) mbarrier.arrive             <args>  valid, release/cta are implied
//   (B) mbarrier.arrive.release.cta <args>  valid, both given explicitly
//   (C) mbarrier.arrive.release     <args>  invalid, needs scope
//   (D) mbarrier.arrive.cta         <args>  invalid, needs semantics
//
// Form (A) is preferred over (B) wherever possible because it is accepted
// by early PTX versions; in most cases an explicit scope requires a later
// PTX version.
class MBAR_UTIL<string op, string scope,
                string space = "", string sem = "",
                bit tl = 0, bit parity = 0> {
  // Scope as printed in the assembly. The cta scope is printed only when a
  // semantics qualifier is present; any other value is passed through as is.
  string _scope_asm = !if(!eq(scope, "scope_cluster"), "cluster",
                      !if(!eq(scope, "scope_cta"),
                          !if(!empty(sem), "", "cta"),
                          scope));

  // State space as printed in the assembly; unrecognized values (including
  // the empty default) are passed through unchanged.
  string _space_asm = !if(!eq(space, "space_cta"), "shared",
                      !if(!eq(space, "space_cluster"), "shared::cluster",
                          space));

  // "parity" when the parity form is requested, otherwise empty.
  string _parity = !cond(parity : "parity",
                         true   : "");

  // PTX mnemonic, assembled from the '.'-separated components.
  string asm_str = StrJoin<".", ["mbarrier", op, _parity, sem,
                                 _scope_asm, _space_asm, "b64"]>.ret;

  // Intrinsic record name. Dots in the op are turned into underscores, and
  // the raw `scope`/`space` arguments (not their assembly spellings) are used.
  string _intr_suffix = StrJoin<"_", [!subst(".", "_", op), _parity,
                                      !cond(tl : "tl", true : ""),
                                      sem, scope, space]>.ret;
  string intr_name = !strconcat("int_nvvm_mbarrier_", _intr_suffix);

  // Predicate selection.
  // Only the "test_wait/try_wait" variants consume these: their requirements
  // have evolved since sm80 and are complex. The remaining instructions have
  // straightforward requirements that are applied directly at their
  // definitions.
  //
  // SM: try_wait, the cluster scope and relaxed semantics select hasSM<90>;
  // everything else selects hasSM<80>.
  Predicate _sm_pred = !if(!or(!eq(op, "try_wait"),
                               !eq(scope, "scope_cluster"),
                               !eq(sem, "relaxed")),
                           hasSM<90>,
                           hasSM<80>);

  // PTX: the first matching case wins, so the cases are ordered from the
  // highest required version down to the lowest.
  Predicate _ptx_pred = !cond(
                        !eq(sem, "relaxed") : hasPTX<86>,
                        !ne(_scope_asm, "") : hasPTX<80>,
                        !eq(op, "try_wait") : hasPTX<78>,
                        parity              : hasPTX<71>,
                        true                : hasPTX<70>);

  list<Predicate> preds = [_ptx_pred, _sm_pred];
}
1143+ 
// expect_tx / complete_tx: one instruction per (op, scope, space) triple.
// The intrinsic name is formed without a semantics component, whereas the
// assembly string always carries the explicit "relaxed" qualifier.
foreach tx_op = ["expect_tx", "complete_tx"] in {
  foreach tx_scope = ["scope_cta", "scope_cluster"] in {
    foreach tx_space = ["space_cta", "space_cluster"] in {
      defvar tx_intr = !cast<Intrinsic>(
                         MBAR_UTIL<tx_op, tx_scope, tx_space>.intr_name);
      defvar tx_asm  = MBAR_UTIL<tx_op, tx_scope, tx_space, "relaxed">.asm_str;
      defvar tx_name = StrJoin<"_", [tx_op, tx_scope, tx_space]>.ret;

      // No result; takes the barrier address and a 32-bit tx-count.
      def mbar_ # tx_name
        : BasicNVPTXInst<(outs), (ins ADDR:$addr, B32:$tx_count),
                         tx_asm,
                         [(tx_intr addr:$addr, i32:$tx_count)]>,
          Requires<[hasPTX<80>, hasSM<90>]>;
    } // tx_space
  } // tx_scope
} // tx_op
1156+ 
// MBAR_ARR_INTR defines the two state-space flavours of an arrive-style
// mbarrier instruction:
//   _CTA     - barrier in shared::cta; produces the 64-bit $state result.
//   _CLUSTER - barrier in shared::cluster; has no result and prints the
//              "_" sink in the destination position.
//
// Parameters:
//   op    - PTX op name (may contain '.', e.g. "arrive.expect_tx").
//   scope - "scope_cta" or "scope_cluster".
//   sem   - semantics used in the intrinsic name; "" selects the default.
//   pred  - predicates attached to both instructions.
multiclass MBAR_ARR_INTR<string op, string scope, string sem,
                         list<Predicate> pred = []> {
  // Semantics and scope must be spelled out together once either one is
  // non-default. With the cluster scope and a defaulted `sem`, print the
  // `release` semantics explicitly in the assembly.
  defvar printed_sem = !if(!and(!eq(scope, "scope_cluster"), !empty(sem)),
                           "release", sem);

  // The intrinsic names are built from the caller's `sem` as given; only
  // the mnemonics use the (possibly explicit) printed semantics.
  defvar cta_intr     = !cast<Intrinsic>(
                          MBAR_UTIL<op, scope, "space_cta", sem>.intr_name);
  defvar cluster_intr = !cast<Intrinsic>(
                          MBAR_UTIL<op, scope, "space_cluster", sem>.intr_name);

  defvar cta_mnemonic =
      MBAR_UTIL<op, scope, "space_cta", printed_sem>.asm_str;
  defvar cluster_mnemonic =
      MBAR_UTIL<op, scope, "space_cluster", printed_sem>.asm_str;

  def _CTA
    : NVPTXInst<(outs B64:$state),
                (ins ADDR:$addr, B32:$tx_count),
                !strconcat(cta_mnemonic, " $state, [$addr], $tx_count;"),
                [(set i64:$state, (cta_intr addr:$addr, i32:$tx_count))]>,
      Requires<pred>;

  def _CLUSTER
    : NVPTXInst<(outs),
                (ins ADDR:$addr, B32:$tx_count),
                !strconcat(cluster_mnemonic, " _, [$addr], $tx_count;"),
                [(cluster_intr addr:$addr, i32:$tx_count)]>,
      Requires<pred>;
}
// Instantiate the arrive-family instructions for every (op, scope) pair,
// once with the default (release) semantics and once with relaxed.
// Record names take the form
//   mbar_<op>_<scope>_{release,relaxed}_{CTA,CLUSTER}
// where any '.' in the op is replaced by '_'.
foreach op = ["arrive", "arrive.expect_tx",
              "arrive_drop", "arrive_drop.expect_tx"] in {
  foreach scope = ["scope_cta", "scope_cluster"] in {
    // Separate the op from the scope with '_' so that the generated names
    // follow the same <op>_<scope> pattern as the other mbarrier records
    // (previously they ran together, e.g. "arrive_dropscope_cta").
    defvar suffix = !subst(".", "_", op) # "_" # scope;
    defm mbar_ # suffix # "_release"
      : MBAR_ARR_INTR<op, scope, "", [hasPTX<80>, hasSM<90>]>;
    defm mbar_ # suffix # "_relaxed"
      : MBAR_ARR_INTR<op, scope, "relaxed", [hasPTX<86>, hasSM<90>]>;
  } // scope
} // op
1192+ 
// MBAR_WAIT_INTR defines the two operand flavours of a wait-style mbarrier
// instruction, both on a shared::cta barrier and both producing a 1-bit
// result:
//   _STATE  - waits on a 64-bit $state token.
//   _PARITY - waits on a 32-bit $phase value (the ".parity" form).
//
// Parameters:
//   op         - "test_wait" or "try_wait".
//   scope      - "scope_cta" or "scope_cluster".
//   sem        - semantics used in the intrinsic name; "" selects the default.
//   time_limit - when set, an extra 32-bit $tl operand is appended.
multiclass MBAR_WAIT_INTR<string op, string scope, string sem, bit time_limit> {
  // Semantics and scope must be spelled out together once either one is
  // non-default. With the cluster scope and a defaulted `sem`, print the
  // `acquire` semantics explicitly in the assembly.
  defvar printed_sem = !if(!and(!eq(scope, "scope_cluster"), !empty(sem)),
                           "acquire", sem);

  // State form: mnemonic and predicates use the printed semantics, while
  // the intrinsic name uses the caller's `sem` as given.
  defvar state_mnemonic = MBAR_UTIL<op, scope, "space_cta", printed_sem,
                                    time_limit>.asm_str;
  defvar state_preds    = MBAR_UTIL<op, scope, "space_cta", printed_sem,
                                    time_limit>.preds;
  defvar state_intr     = !cast<Intrinsic>(MBAR_UTIL<op, scope, "space_cta",
                                           sem, time_limit>.intr_name);

  // Parity form: same derivation with the parity bit set.
  defvar parity_mnemonic = MBAR_UTIL<op, scope, "space_cta", printed_sem,
                                     time_limit, 1>.asm_str;
  defvar parity_preds    = MBAR_UTIL<op, scope, "space_cta", printed_sem,
                                     time_limit, 1>.preds;
  defvar parity_intr     = !cast<Intrinsic>(MBAR_UTIL<op, scope, "space_cta",
                                            sem, time_limit, 1>.intr_name);

  // Tail of the operand string: the optional time limit, then the ';'.
  defvar operand_tail = !if(time_limit, ", $tl;", ";");

  // Input operands and selection patterns, with or without the trailing
  // time-limit operand.
  defvar state_ins  = !if(time_limit,
                          (ins ADDR:$addr, B64:$state, B32:$tl),
                          (ins ADDR:$addr, B64:$state));
  defvar parity_ins = !if(time_limit,
                          (ins ADDR:$addr, B32:$phase, B32:$tl),
                          (ins ADDR:$addr, B32:$phase));

  defvar state_call  = !if(time_limit,
                           (state_intr addr:$addr, i64:$state, i32:$tl),
                           (state_intr addr:$addr, i64:$state));
  defvar parity_call = !if(time_limit,
                           (parity_intr addr:$addr, i32:$phase, i32:$tl),
                           (parity_intr addr:$addr, i32:$phase));

  def _STATE
    : NVPTXInst<(outs B1:$res), state_ins,
                !strconcat(state_mnemonic, " $res, [$addr], $state",
                           operand_tail),
                [(set i1:$res, state_call)]>,
      Requires<state_preds>;

  def _PARITY
    : NVPTXInst<(outs B1:$res), parity_ins,
                !strconcat(parity_mnemonic, " $res, [$addr], $phase",
                           operand_tail),
                [(set i1:$res, parity_call)]>,
      Requires<parity_preds>;
}
// Instantiate the wait-family instructions for every (op, scope) pair, once
// with the default (acquire) semantics and once with relaxed. Only try_wait
// is additionally instantiated with a time-limit operand.
foreach wait_op = ["test_wait", "try_wait"] in {
  foreach wait_scope = ["scope_cta", "scope_cluster"] in {
    foreach has_tl = !if(!ne(wait_op, "try_wait"), [false], [true, false]) in {
      // <op>_<scope>[_tl]
      defvar stem = StrJoin<"_", [wait_op, wait_scope,
                                  !if(has_tl, "tl", "")]>.ret;
      defm mbar_ # stem # "_acquire"
        : MBAR_WAIT_INTR<wait_op, wait_scope, "", has_tl>;
      defm mbar_ # stem # "_relaxed"
        : MBAR_WAIT_INTR<wait_op, wait_scope, "relaxed", has_tl>;
    } // has_tl
  } // wait_scope
} // wait_op
1239+ 
10851240//-----------------------------------
10861241// Math Functions
10871242//-----------------------------------
0 commit comments