Skip to content

Commit 4f3be47

Browse files
committed
Address review comments
1 parent 0308d0d commit 4f3be47

File tree

4 files changed

+26
-38
lines changed

4 files changed

+26
-38
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1236,24 +1236,8 @@ def NVVM_FenceProxyAcquireOp : NVVM_Op<"fence.proxy.acquire">,
12361236
let hasVerifier = 1;
12371237
}
12381238

1239-
// Attrs describing the level of the Memory Operation
1240-
def MemLevelCTA : I32EnumAttrCase<"CTA", 0, "cta">;
1241-
def MemLevelGL : I32EnumAttrCase<"GL", 1, "gl">;
1242-
def MemLevelSys : I32EnumAttrCase<"SYS", 2, "sys">;
1243-
1244-
def MemLevelKind
1245-
: I32EnumAttr<
1246-
"MemLevelKind",
1247-
"NVVM Memory Level kind", [MemLevelCTA, MemLevelGL, MemLevelSys]> {
1248-
let genSpecializedAttr = 0;
1249-
let cppNamespace = "::mlir::NVVM";
1250-
}
1251-
def MemLevelKindAttr : EnumAttr<NVVM_Dialect, MemLevelKind, "mem_level"> {
1252-
let assemblyFormat = "`<` $value `>`";
1253-
}
1254-
1255-
def NVVM_MembarOp : NVVM_Op<"membar">,
1256-
Arguments<(ins MemLevelKindAttr:$level)> {
1239+
def NVVM_MembarOp : NVVM_Op<"memory_barrier">,
1240+
Arguments<(ins MemScopeKindAttr:$scope)> {
12571241
let summary = "Memory barrier operation";
12581242
let description = [{
12591243
`membar` operation guarantees that prior memory accesses requested by this
@@ -1263,9 +1247,9 @@ def NVVM_MembarOp : NVVM_Op<"membar">,
12631247
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-membar)
12641248
}];
12651249

1266-
let assemblyFormat = "$level attr-dict";
1250+
let assemblyFormat = "$scope attr-dict";
12671251
let llvmBuilder = [{
1268-
createIntrinsicCall(builder, getMembarLevelID($level), {});
1252+
createIntrinsicCall(builder, getMemoryBarrierLevelID($scope), {});
12691253
}];
12701254
}
12711255

mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -291,15 +291,18 @@ static unsigned getUnidirectionalFenceProxyID(NVVM::ProxyKind fromProxy,
291291
llvm_unreachable("Unsupported proxy kinds");
292292
}
293293

294-
static unsigned getMembarLevelID(NVVM::MemLevelKind level) {
295-
switch (level) {
296-
case NVVM::MemLevelKind::CTA: {
294+
static unsigned getMemoryBarrierLevelID(NVVM::MemScopeKind scope) {
295+
switch (scope) {
296+
case NVVM::MemScopeKind::CTA: {
297297
return llvm::Intrinsic::nvvm_membar_cta;
298298
}
299-
case NVVM::MemLevelKind::GL: {
299+
case NVVM::MemScopeKind::CLUSTER: {
300+
return llvm::Intrinsic::nvvm_fence_sc_cluster;
301+
}
302+
case NVVM::MemScopeKind::GPU: {
300303
return llvm::Intrinsic::nvvm_membar_gl;
301304
}
302-
case NVVM::MemLevelKind::SYS: {
305+
case NVVM::MemScopeKind::SYS: {
303306
return llvm::Intrinsic::nvvm_membar_sys;
304307
}
305308
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// RUN: mlir-translate -mlir-to-llvmir %s -split-input-file --verify-diagnostics | FileCheck %s
2+
3+
// CHECK-lABEL: @memorybarrier()
4+
llvm.func @memorybarrier() {
5+
// CHECK: call void @llvm.nvvm.membar.cta()
6+
nvvm.memory_barrier #nvvm.mem_scope<cta>
7+
// CHECK: call void @llvm.nvvm.fence.sc.cluster()
8+
nvvm.memory_barrier #nvvm.mem_scope<cluster>
9+
// CHECK: call void @llvm.nvvm.membar.gl()
10+
nvvm.memory_barrier #nvvm.mem_scope<gpu>
11+
// CHECK: call void @llvm.nvvm.membar.sys()
12+
nvvm.memory_barrier #nvvm.mem_scope<sys>
13+
llvm.return
14+
}

mlir/test/Target/LLVMIR/nvvmir.mlir

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -975,16 +975,3 @@ llvm.func @nanosleep() {
975975
nvvm.nanosleep 4000
976976
llvm.return
977977
}
978-
979-
// -----
980-
981-
// CHECK-lABEL: @memorybarrier()
982-
llvm.func @memorybarrier() {
983-
// CHECK: call void @llvm.nvvm.membar.cta()
984-
nvvm.membar #nvvm.mem_level<cta>
985-
// CHECK: call void @llvm.nvvm.membar.gl()
986-
nvvm.membar #nvvm.mem_level<gl>
987-
// CHECK: call void @llvm.nvvm.membar.sys()
988-
nvvm.membar #nvvm.mem_level<sys>
989-
llvm.return
990-
}

0 commit comments

Comments
 (0)