Skip to content

Commit eb6f233

Browse files
committed
[MLIR][NVVM] Update mbarrier.test.wait Op
This patch extends the mbarrier.test.wait Op to support scope, semantics, and phase-parity. This completes the updates to test_wait up to Blackwell. Lit tests are added to verify the lowering. Signed-off-by: Durgadoss R <[email protected]>
1 parent 5c26015 commit eb6f233

File tree

6 files changed

+142
-60
lines changed

6 files changed

+142
-60
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1052,31 +1052,35 @@ def NVVM_MBarrierTryWaitParityOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity"
10521052
let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)";
10531053
}
10541054

1055-
def NVVM_MBarrierTestWaitOp : NVVM_Op<"mbarrier.test.wait">,
1056-
Results<(outs I1:$res)>,
1057-
Arguments<(ins AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$addr,
1058-
I64:$state)> {
1055+
def NVVM_MBarrierTestWaitOp : NVVM_Op<"mbarrier.test.wait"> {
10591056
let summary = "MBarrier Non-Blocking Test Wait Operation";
10601057
let description = [{
1061-
The `nvvm.mbarrier.test.wait` operation performs a non-blocking test for the
1058+
The `nvvm.mbarrier.test.wait` operation performs a non-blocking test for the
10621059
completion of a specific phase of an *mbarrier object*. It uses the default
1063-
`.acquire.cta` semantics. This acquire pattern establishes memory ordering for
1064-
operations occurring in program order after this wait instruction by making
1065-
operations from other threads in the CTA visible to subsequent operations in the current
1066-
thread. When this wait completes, it synchronizes with the corresponding release
1067-
pattern from the `mbarrier.arrive` operation, establishing memory ordering within
1060+
`.acquire.cta` semantics. This acquire pattern establishes memory ordering for
1061+
operations occurring in program order after this wait instruction by making
1062+
operations from other threads in the CTA visible to subsequent operations in the current
1063+
thread. When this wait completes, it synchronizes with the corresponding release
1064+
pattern from the `mbarrier.arrive` operation, establishing memory ordering within
10681065
the CTA.
10691066

1070-
This operation tests whether the mbarrier phase specified by the state operand
1071-
has completed. It is a non-blocking instruction that immediately returns the
1067+
This operation tests whether the mbarrier phase specified by the state operand
1068+
has completed. It is a non-blocking instruction that immediately returns the
10721069
completion status without suspending the executing thread.
10731070

10741071
The operation takes the following operands:
1075-
- `addr`: A pointer to the memory location of the *mbarrier object*. Uses generic
1072+
- `addr`: A pointer to the memory location of the *mbarrier object*. Uses generic
10761073
addressing, but the address must still be in the shared memory space.
1077-
- `state`: An opaque value returned by a previous `mbarrier.arrive`
1078-
operation on the same *mbarrier object* during the current or immediately
1079-
preceding phase.
1074+
- `stateOrPhase`: This argument represents a `state` when it is a 64-bit value
1075+
and represents a `phase` when it is a 32-bit value. The `state` is an opaque
1076+
value returned by a previous `mbarrier.arrive` operation on the same
1077+
*mbarrier object* during the current or immediately preceding phase.
1078+
The `phase` is an integer specifying the phase parity (0 or 1).
1079+
Even phases have parity 0, odd phases have parity 1.
1080+
- `scope`: This specifies the set of threads that directly observe the memory
1081+
synchronizing effect of the `mbarrier.test.wait` operation.
1082+
- `relaxed`: When set to true, the `mbarrier.test.wait` operation has relaxed memory semantics
1083+
and does not provide any ordering or visibility guarantees.
10801084

10811085
The operation returns a boolean value indicating whether the specified phase
10821086
has completed:
@@ -1103,7 +1107,15 @@ def NVVM_MBarrierTestWaitOp : NVVM_Op<"mbarrier.test.wait">,
11031107
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-try-wait)
11041108
}];
11051109

1106-
let assemblyFormat = "$addr `,` $state attr-dict `:` type(operands) `->` type($res)";
1110+
let results = (outs I1:$res);
1111+
let arguments = (ins
1112+
AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$addr,
1113+
AnyTypeOf<[I64, I32]>:$stateOrPhase,
1114+
DefaultValuedAttr<MemScopeKindAttr, "MemScopeKind::CTA">:$scope,
1115+
DefaultValuedAttr<BoolAttr, "false">:$relaxed);
1116+
1117+
let assemblyFormat = "$addr `,` $stateOrPhase attr-dict `:` type(operands) `->` type($res)";
1118+
let hasVerifier = 1;
11071119

11081120
let extraClassDeclaration = [{
11091121
static mlir::NVVM::IDArgPair

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -252,10 +252,10 @@ LogicalResult CpAsyncBulkGlobalToSharedClusterOp::verify() {
252252
static LogicalResult verifyMBarrierArriveLikeOp(Operation *op, Value addr,
253253
NVVM::MemScopeKind scope,
254254
Value retVal = nullptr) {
255-
bool isSharedCluster = isPtrInSharedClusterSpace(addr);
256255
if (scope != NVVM::MemScopeKind::CTA && scope != NVVM::MemScopeKind::CLUSTER)
257256
return op->emitError("mbarrier scope must be either CTA or Cluster");
258257

258+
bool isSharedCluster = isPtrInSharedClusterSpace(addr);
259259
bool hasRetValue = static_cast<bool>(retVal);
260260
if (isSharedCluster && hasRetValue)
261261
return op->emitError(
@@ -310,6 +310,10 @@ LogicalResult MBarrierCompleteTxOp::verify() {
310310
return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope());
311311
}
312312

313+
LogicalResult MBarrierTestWaitOp::verify() {
314+
return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope());
315+
}
316+
313317
LogicalResult ConvertFloatToTF32Op::verify() {
314318
using RndMode = NVVM::FPRoundingMode;
315319
switch (getRnd()) {
@@ -2718,16 +2722,34 @@ mlir::NVVM::IDArgPair MBarrierArriveDropNocompleteOp::getIntrinsicIDAndArgs(
27182722
mlir::NVVM::IDArgPair MBarrierTestWaitOp::getIntrinsicIDAndArgs(
27192723
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
27202724
auto thisOp = cast<NVVM::MBarrierTestWaitOp>(op);
2721-
bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
2722-
llvm::Intrinsic::ID id = isShared
2723-
? llvm::Intrinsic::nvvm_mbarrier_test_wait_shared
2724-
: llvm::Intrinsic::nvvm_mbarrier_test_wait;
2725-
// Fill the Intrinsic Args
2726-
llvm::SmallVector<llvm::Value *> args;
2727-
args.push_back(mt.lookupValue(thisOp.getAddr()));
2728-
args.push_back(mt.lookupValue(thisOp.getState()));
2725+
bool isPhaseParity = thisOp.getStateOrPhase().getType().isInteger(32);
2726+
bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
2727+
// bit-0: isPhaseParity
2728+
// bit-1: Scope
2729+
size_t index = ((isClusterScope ? 1 : 0) << 1) | (isPhaseParity ? 1 : 0);
27292730

2730-
return {id, std::move(args)};
2731+
// clang-format off
2732+
static constexpr llvm::Intrinsic::ID IDs[] = {
2733+
llvm::Intrinsic::nvvm_mbarrier_test_wait_scope_cta_space_cta,
2734+
llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_scope_cta_space_cta,
2735+
llvm::Intrinsic::nvvm_mbarrier_test_wait_scope_cluster_space_cta,
2736+
llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_scope_cluster_space_cta};
2737+
static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
2738+
llvm::Intrinsic::nvvm_mbarrier_test_wait_relaxed_scope_cta_space_cta,
2739+
llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_relaxed_scope_cta_space_cta,
2740+
llvm::Intrinsic::nvvm_mbarrier_test_wait_relaxed_scope_cluster_space_cta,
2741+
llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_relaxed_scope_cluster_space_cta};
2742+
// clang-format on
2743+
auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
2744+
2745+
// Tidy-up the Intrinsic Args
2746+
llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
2747+
llvm::Value *input = mt.lookupValue(thisOp.getStateOrPhase());
2748+
bool needCast = isPtrInGenericSpace(thisOp.getAddr());
2749+
if (needCast)
2750+
mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
2751+
2752+
return {id, {mbar, input}};
27312753
}
27322754

27332755
mlir::NVVM::IDArgPair CpAsyncMBarrierArriveOp::getIntrinsicIDAndArgs(

mlir/test/Dialect/LLVMIR/nvvm.mlir

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -464,19 +464,6 @@ llvm.func private @mbarrier_arrive_nocomplete_shared(%barrier: !llvm.ptr<3>) {
464464
llvm.return
465465
}
466466

467-
llvm.func private @mbarrier_test_wait(%barrier: !llvm.ptr, %token : i64) -> i1 {
468-
// CHECK: nvvm.mbarrier.test.wait %{{.*}}
469-
%isComplete = nvvm.mbarrier.test.wait %barrier, %token : !llvm.ptr, i64 -> i1
470-
llvm.return %isComplete : i1
471-
}
472-
473-
llvm.func private @mbarrier_test_wait_shared(%barrier: !llvm.ptr<3>, %token : i64) {
474-
%count = nvvm.read.ptx.sreg.ntid.x : i32
475-
// CHECK: nvvm.mbarrier.test.wait %{{.*}}
476-
%isComplete = nvvm.mbarrier.test.wait %barrier, %token : !llvm.ptr<3>, i64 -> i1
477-
llvm.return
478-
}
479-
480467
// CHECK-LABEL: @wgmma_fence_aligned
481468
func.func @wgmma_fence_aligned() {
482469
// CHECK: nvvm.wgmma.fence.aligned

mlir/test/Target/LLVMIR/nvvm/mbar_init.mlir

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -54,23 +54,3 @@ llvm.func @mbarrier_inval_shared(%barrier: !llvm.ptr<3>) {
5454
nvvm.mbarrier.inval %barrier : !llvm.ptr<3>
5555
llvm.return
5656
}
57-
58-
llvm.func @mbarrier_test_wait(%barrier: !llvm.ptr, %token : i64) -> i1 {
59-
// CHECK-LABEL: define i1 @mbarrier_test_wait(ptr %0, i64 %1) {
60-
// CHECK-NEXT: %3 = call i1 @llvm.nvvm.mbarrier.test.wait(ptr %0, i64 %1)
61-
// CHECK-NEXT: ret i1 %3
62-
// CHECK-NEXT: }
63-
%isComplete = nvvm.mbarrier.test.wait %barrier, %token : !llvm.ptr, i64 -> i1
64-
llvm.return %isComplete : i1
65-
}
66-
67-
llvm.func @mbarrier_test_wait_shared(%barrier: !llvm.ptr<3>, %token : i64) {
68-
// CHECK-LABEL: define void @mbarrier_test_wait_shared(ptr addrspace(3) %0, i64 %1) {
69-
// CHECK-NEXT: %3 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
70-
// CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.test.wait.shared(ptr addrspace(3) %0, i64 %1)
71-
// CHECK-NEXT: ret void
72-
// CHECK-NEXT: }
73-
%count = nvvm.read.ptx.sreg.ntid.x : i32
74-
%isComplete = nvvm.mbarrier.test.wait %barrier, %token : !llvm.ptr<3>, i64 -> i1
75-
llvm.return
76-
}

mlir/test/Target/LLVMIR/nvvm/mbar_invalid.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,3 +112,11 @@ llvm.func @mbarrier_arr_drop_expect_tx_cluster(%barrier: !llvm.ptr<7>, %tx_count
112112
llvm.return
113113
}
114114

115+
// -----
116+
117+
llvm.func @mbarrier_test_wait(%barrier: !llvm.ptr<3>, %phase: i32) {
118+
// expected-error @below {{mbarrier scope must be either CTA or Cluster}}
119+
%1 = nvvm.mbarrier.test.wait %barrier, %phase {scope = #nvvm.mem_scope<gpu>} : !llvm.ptr<3>, i32 -> i1
120+
llvm.return
121+
}
122+
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
2+
3+
llvm.func @mbarrier_test_wait_state(%barrier: !llvm.ptr, %state : i64) {
4+
// CHECK-LABEL: define void @mbarrier_test_wait_state(ptr %0, i64 %1) {
5+
// CHECK-NEXT: %3 = addrspacecast ptr %0 to ptr addrspace(3)
6+
// CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.test.wait.scope.cta.space.cta(ptr addrspace(3) %3, i64 %1)
7+
// CHECK-NEXT: %5 = addrspacecast ptr %0 to ptr addrspace(3)
8+
// CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.test.wait.scope.cluster.space.cta(ptr addrspace(3) %5, i64 %1)
9+
// CHECK-NEXT: %7 = addrspacecast ptr %0 to ptr addrspace(3)
10+
// CHECK-NEXT: %8 = call i1 @llvm.nvvm.mbarrier.test.wait.relaxed.scope.cta.space.cta(ptr addrspace(3) %7, i64 %1)
11+
// CHECK-NEXT: %9 = addrspacecast ptr %0 to ptr addrspace(3)
12+
// CHECK-NEXT: %10 = call i1 @llvm.nvvm.mbarrier.test.wait.relaxed.scope.cluster.space.cta(ptr addrspace(3) %9, i64 %1)
13+
// CHECK-NEXT: ret void
14+
// CHECK-NEXT: }
15+
%0 = nvvm.mbarrier.test.wait %barrier, %state : !llvm.ptr, i64 -> i1
16+
%1 = nvvm.mbarrier.test.wait %barrier, %state {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i64 -> i1
17+
18+
%2 = nvvm.mbarrier.test.wait %barrier, %state {relaxed = true} : !llvm.ptr, i64 -> i1
19+
%3 = nvvm.mbarrier.test.wait %barrier, %state {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i64 -> i1
20+
llvm.return
21+
}
22+
23+
llvm.func @mbarrier_test_wait_shared_state(%barrier: !llvm.ptr<3>, %state : i64) {
24+
// CHECK-LABEL: define void @mbarrier_test_wait_shared_state(ptr addrspace(3) %0, i64 %1) {
25+
// CHECK-NEXT: %3 = call i1 @llvm.nvvm.mbarrier.test.wait.scope.cta.space.cta(ptr addrspace(3) %0, i64 %1)
26+
// CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.test.wait.scope.cluster.space.cta(ptr addrspace(3) %0, i64 %1)
27+
// CHECK-NEXT: %5 = call i1 @llvm.nvvm.mbarrier.test.wait.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i64 %1)
28+
// CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.test.wait.relaxed.scope.cluster.space.cta(ptr addrspace(3) %0, i64 %1)
29+
// CHECK-NEXT: ret void
30+
// CHECK-NEXT: }
31+
%0 = nvvm.mbarrier.test.wait %barrier, %state : !llvm.ptr<3>, i64 -> i1
32+
%1 = nvvm.mbarrier.test.wait %barrier, %state {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i64 -> i1
33+
34+
%2 = nvvm.mbarrier.test.wait %barrier, %state {relaxed = true} : !llvm.ptr<3>, i64 -> i1
35+
%3 = nvvm.mbarrier.test.wait %barrier, %state {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i64 -> i1
36+
llvm.return
37+
}
38+
39+
llvm.func @mbarrier_test_wait_phase(%barrier: !llvm.ptr, %phase : i32) {
40+
// CHECK-LABEL: define void @mbarrier_test_wait_phase(ptr %0, i32 %1) {
41+
// CHECK-NEXT: %3 = addrspacecast ptr %0 to ptr addrspace(3)
42+
// CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.scope.cta.space.cta(ptr addrspace(3) %3, i32 %1)
43+
// CHECK-NEXT: %5 = addrspacecast ptr %0 to ptr addrspace(3)
44+
// CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.scope.cluster.space.cta(ptr addrspace(3) %5, i32 %1)
45+
// CHECK-NEXT: %7 = addrspacecast ptr %0 to ptr addrspace(3)
46+
// CHECK-NEXT: %8 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.relaxed.scope.cta.space.cta(ptr addrspace(3) %7, i32 %1)
47+
// CHECK-NEXT: %9 = addrspacecast ptr %0 to ptr addrspace(3)
48+
// CHECK-NEXT: %10 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.relaxed.scope.cluster.space.cta(ptr addrspace(3) %9, i32 %1)
49+
// CHECK-NEXT: ret void
50+
// CHECK-NEXT: }
51+
%0 = nvvm.mbarrier.test.wait %barrier, %phase : !llvm.ptr, i32 -> i1
52+
%1 = nvvm.mbarrier.test.wait %barrier, %phase {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i32 -> i1
53+
54+
%2 = nvvm.mbarrier.test.wait %barrier, %phase {relaxed = true} : !llvm.ptr, i32 -> i1
55+
%3 = nvvm.mbarrier.test.wait %barrier, %phase {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i32 -> i1
56+
llvm.return
57+
}
58+
59+
llvm.func @mbarrier_test_wait_shared_phase(%barrier: !llvm.ptr<3>, %phase : i32) {
60+
// CHECK-LABEL: define void @mbarrier_test_wait_shared_phase(ptr addrspace(3) %0, i32 %1) {
61+
// CHECK-NEXT: %3 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1)
62+
// CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1)
63+
// CHECK-NEXT: %5 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1)
64+
// CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.relaxed.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1)
65+
// CHECK-NEXT: ret void
66+
// CHECK-NEXT: }
67+
%0 = nvvm.mbarrier.test.wait %barrier, %phase : !llvm.ptr<3>, i32 -> i1
68+
%1 = nvvm.mbarrier.test.wait %barrier, %phase {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i32 -> i1
69+
70+
%2 = nvvm.mbarrier.test.wait %barrier, %phase {relaxed = true} : !llvm.ptr<3>, i32 -> i1
71+
%3 = nvvm.mbarrier.test.wait %barrier, %phase {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i32 -> i1
72+
llvm.return
73+
}

0 commit comments

Comments
 (0)