Skip to content

Commit 606a0c2

Browse files
authored
[flang][cuda][NFC] Use NVVM barrier op with reduction (#167940)
Simplify the lowering by using the barrier op from NVVM updated in #167036
1 parent 92e5608 commit 606a0c2

File tree

2 files changed

+30
-33
lines changed

2 files changed

+30
-33
lines changed

flang/lib/Optimizer/Builder/CUDAIntrinsicCall.cpp

Lines changed: 21 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1080,42 +1080,39 @@ void CUDAIntrinsicLibrary::genSyncThreads(
10801080
mlir::Value
10811081
CUDAIntrinsicLibrary::genSyncThreadsAnd(mlir::Type resultType,
10821082
llvm::ArrayRef<mlir::Value> args) {
1083-
constexpr llvm::StringLiteral funcName = "llvm.nvvm.barrier0.and";
1084-
mlir::MLIRContext *context = builder.getContext();
1085-
mlir::Type i32 = builder.getI32Type();
1086-
mlir::FunctionType ftype =
1087-
mlir::FunctionType::get(context, {resultType}, {i32});
1088-
auto funcOp = builder.createFunction(loc, funcName, ftype);
1089-
mlir::Value arg = builder.createConvert(loc, i32, args[0]);
1090-
return fir::CallOp::create(builder, loc, funcOp, {arg}).getResult(0);
1083+
mlir::Value arg = builder.createConvert(loc, builder.getI32Type(), args[0]);
1084+
return mlir::NVVM::BarrierOp::create(
1085+
builder, loc, resultType, {}, {},
1086+
mlir::NVVM::BarrierReductionAttr::get(
1087+
builder.getContext(), mlir::NVVM::BarrierReduction::AND),
1088+
arg)
1089+
.getResult(0);
10911090
}
10921091

10931092
// SYNCTHREADS_COUNT
10941093
mlir::Value
10951094
CUDAIntrinsicLibrary::genSyncThreadsCount(mlir::Type resultType,
10961095
llvm::ArrayRef<mlir::Value> args) {
1097-
constexpr llvm::StringLiteral funcName = "llvm.nvvm.barrier0.popc";
1098-
mlir::MLIRContext *context = builder.getContext();
1099-
mlir::Type i32 = builder.getI32Type();
1100-
mlir::FunctionType ftype =
1101-
mlir::FunctionType::get(context, {resultType}, {i32});
1102-
auto funcOp = builder.createFunction(loc, funcName, ftype);
1103-
mlir::Value arg = builder.createConvert(loc, i32, args[0]);
1104-
return fir::CallOp::create(builder, loc, funcOp, {arg}).getResult(0);
1096+
mlir::Value arg = builder.createConvert(loc, builder.getI32Type(), args[0]);
1097+
return mlir::NVVM::BarrierOp::create(
1098+
builder, loc, resultType, {}, {},
1099+
mlir::NVVM::BarrierReductionAttr::get(
1100+
builder.getContext(), mlir::NVVM::BarrierReduction::POPC),
1101+
arg)
1102+
.getResult(0);
11051103
}
11061104

11071105
// SYNCTHREADS_OR
11081106
mlir::Value
11091107
CUDAIntrinsicLibrary::genSyncThreadsOr(mlir::Type resultType,
11101108
llvm::ArrayRef<mlir::Value> args) {
1111-
constexpr llvm::StringLiteral funcName = "llvm.nvvm.barrier0.or";
1112-
mlir::MLIRContext *context = builder.getContext();
1113-
mlir::Type i32 = builder.getI32Type();
1114-
mlir::FunctionType ftype =
1115-
mlir::FunctionType::get(context, {resultType}, {i32});
1116-
auto funcOp = builder.createFunction(loc, funcName, ftype);
1117-
mlir::Value arg = builder.createConvert(loc, i32, args[0]);
1118-
return fir::CallOp::create(builder, loc, funcOp, {arg}).getResult(0);
1109+
mlir::Value arg = builder.createConvert(loc, builder.getI32Type(), args[0]);
1110+
return mlir::NVVM::BarrierOp::create(
1111+
builder, loc, resultType, {}, {},
1112+
mlir::NVVM::BarrierReductionAttr::get(
1113+
builder.getContext(), mlir::NVVM::BarrierReduction::OR),
1114+
arg)
1115+
.getResult(0);
11191116
}
11201117

11211118
// SYNCWARP

flang/test/Lower/CUDA/cuda-device-proc.cuf

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -103,24 +103,24 @@ end
103103
! CHECK-LABEL: func.func @_QPdevsub() attributes {cuf.proc_attr = #cuf.cuda_proc<global>}
104104
! CHECK: nvvm.barrier0
105105
! CHECK: nvvm.bar.warp.sync %c1{{.*}} : i32
106-
! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.and(%c1{{.*}}) fastmath<contract> : (i32) -> i32
106+
! CHECK: %{{.*}} = nvvm.barrier <and> %c1{{.*}} -> i32
107107
! CHECK: %[[A:.*]] = fir.load %{{.*}} : !fir.ref<i32>
108108
! CHECK: %[[B:.*]] = fir.load %{{.*}} : !fir.ref<i32>
109109
! CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[A]], %[[B]] : i32
110110
! CHECK: %[[CONV:.*]] = fir.convert %[[CMP]] : (i1) -> i32
111-
! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.and(%[[CONV]])
112-
! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.popc(%c1{{.*}}) fastmath<contract> : (i32) -> i32
111+
! CHECK: %{{.*}} = nvvm.barrier <and> %[[CONV]] -> i32
112+
! CHECK: %{{.*}} = nvvm.barrier <popc> %c1{{.*}} -> i32
113113
! CHECK: %[[A:.*]] = fir.load %{{.*}} : !fir.ref<i32>
114114
! CHECK: %[[B:.*]] = fir.load %{{.*}} : !fir.ref<i32>
115115
! CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[A]], %[[B]] : i32
116116
! CHECK: %[[CONV:.*]] = fir.convert %[[CMP]] : (i1) -> i32
117-
! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.popc(%[[CONV]]) fastmath<contract> : (i32) -> i32
118-
! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.or(%c1{{.*}}) fastmath<contract> : (i32) -> i32
117+
! CHECK: %{{.*}} = nvvm.barrier <popc> %[[CONV]] -> i32
118+
! CHECK: %{{.*}} = nvvm.barrier <or> %c1{{.*}} -> i32
119119
! CHECK: %[[A:.*]] = fir.load %{{.*}} : !fir.ref<i32>
120120
! CHECK: %[[B:.*]] = fir.load %{{.*}} : !fir.ref<i32>
121121
! CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[A]], %[[B]] : i32
122122
! CHECK: %[[CONV:.*]] = fir.convert %[[CMP]] : (i1) -> i32
123-
! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.or(%[[CONV]]) fastmath<contract> : (i32) -> i32
123+
! CHECK: %{{.*}} = nvvm.barrier <or> %[[CONV]] -> i32
124124
! CHECK: %{{.*}} = llvm.atomicrmw add %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, i32
125125
! CHECK: %{{.*}} = llvm.atomicrmw add %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, i64
126126
! CHECK: %{{.*}} = llvm.atomicrmw fadd %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, f32
@@ -214,9 +214,9 @@ end
214214
! CHECK: cuf.kernel
215215
! CHECK: nvvm.barrier0
216216
! CHECK: nvvm.bar.warp.sync %c1{{.*}} : i32
217-
! CHECK: fir.call @llvm.nvvm.barrier0.and(%c1{{.*}}) fastmath<contract> : (i32) -> i32
218-
! CHECK: fir.call @llvm.nvvm.barrier0.popc(%c1{{.*}}) fastmath<contract> : (i32) -> i32
219-
! CHECK: fir.call @llvm.nvvm.barrier0.or(%c1{{.*}}) fastmath<contract> : (i32) -> i32
217+
! CHECK: nvvm.barrier <and> %c1{{.*}} -> i32
218+
! CHECK: nvvm.barrier <popc> %c1{{.*}} -> i32
219+
! CHECK: nvvm.barrier <or> %c1{{.*}} -> i32
220220

221221
attributes(device) subroutine testMatch()
222222
integer :: a, ipred, mask, v32

0 commit comments

Comments
 (0)