Skip to content

Commit ebc35f8

Browse files
authored
[mlir][NVVM] Make sure barrier reduction attr can roundtrip (#167958)
The IR was not able to be roundtrip through mlir-opt. Update the assembly format and add round trip tests. ``` mlir-opt mlir/test/Target/LLVMIR/nvvm/barrier.mlir | mlir-opt <stdin>:6:5: error: cannot name an operation with no results %0 = nvvm.barrier <and> %arg2 -> i32 ```
1 parent 6a89439 commit ebc35f8

File tree

3 files changed

+26
-19
lines changed

3 files changed

+26
-19
lines changed

flang/test/Lower/CUDA/cuda-device-proc.cuf

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -103,24 +103,24 @@ end
103103
! CHECK-LABEL: func.func @_QPdevsub() attributes {cuf.proc_attr = #cuf.cuda_proc<global>}
104104
! CHECK: nvvm.barrier0
105105
! CHECK: nvvm.bar.warp.sync %c1{{.*}} : i32
106-
! CHECK: %{{.*}} = nvvm.barrier <and> %c1{{.*}} -> i32
106+
! CHECK: %{{.*}} = nvvm.barrier #nvvm.reduction<and> %c1{{.*}} -> i32
107107
! CHECK: %[[A:.*]] = fir.load %{{.*}} : !fir.ref<i32>
108108
! CHECK: %[[B:.*]] = fir.load %{{.*}} : !fir.ref<i32>
109109
! CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[A]], %[[B]] : i32
110110
! CHECK: %[[CONV:.*]] = fir.convert %[[CMP]] : (i1) -> i32
111-
! CHECK: %{{.*}} = nvvm.barrier <and> %[[CONV]] -> i32
112-
! CHECK: %{{.*}} = nvvm.barrier <popc> %c1{{.*}} -> i32
111+
! CHECK: %{{.*}} = nvvm.barrier #nvvm.reduction<and> %[[CONV]] -> i32
112+
! CHECK: %{{.*}} = nvvm.barrier #nvvm.reduction<popc> %c1{{.*}} -> i32
113113
! CHECK: %[[A:.*]] = fir.load %{{.*}} : !fir.ref<i32>
114114
! CHECK: %[[B:.*]] = fir.load %{{.*}} : !fir.ref<i32>
115115
! CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[A]], %[[B]] : i32
116116
! CHECK: %[[CONV:.*]] = fir.convert %[[CMP]] : (i1) -> i32
117-
! CHECK: %{{.*}} = nvvm.barrier <popc> %[[CONV]] -> i32
118-
! CHECK: %{{.*}} = nvvm.barrier <or> %c1{{.*}} -> i32
117+
! CHECK: %{{.*}} = nvvm.barrier #nvvm.reduction<popc> %[[CONV]] -> i32
118+
! CHECK: %{{.*}} = nvvm.barrier #nvvm.reduction<or> %c1{{.*}} -> i32
119119
! CHECK: %[[A:.*]] = fir.load %{{.*}} : !fir.ref<i32>
120120
! CHECK: %[[B:.*]] = fir.load %{{.*}} : !fir.ref<i32>
121121
! CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[A]], %[[B]] : i32
122122
! CHECK: %[[CONV:.*]] = fir.convert %[[CMP]] : (i1) -> i32
123-
! CHECK: %{{.*}} = nvvm.barrier <or> %[[CONV]] -> i32
123+
! CHECK: %{{.*}} = nvvm.barrier #nvvm.reduction<or> %[[CONV]] -> i32
124124
! CHECK: %{{.*}} = llvm.atomicrmw add %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, i32
125125
! CHECK: %{{.*}} = llvm.atomicrmw add %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, i64
126126
! CHECK: %{{.*}} = llvm.atomicrmw fadd %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, f32
@@ -214,9 +214,9 @@ end
214214
! CHECK: cuf.kernel
215215
! CHECK: nvvm.barrier0
216216
! CHECK: nvvm.bar.warp.sync %c1{{.*}} : i32
217-
! CHECK: nvvm.barrier <and> %c1{{.*}} -> i32
218-
! CHECK: nvvm.barrier <popc> %c1{{.*}} -> i32
219-
! CHECK: nvvm.barrier <or> %c1{{.*}} -> i32
217+
! CHECK: nvvm.barrier #nvvm.reduction<and> %c1{{.*}} -> i32
218+
! CHECK: nvvm.barrier #nvvm.reduction<popc> %c1{{.*}} -> i32
219+
! CHECK: nvvm.barrier #nvvm.reduction<or> %c1{{.*}} -> i32
220220

221221
attributes(device) subroutine testMatch()
222222
integer :: a, ipred, mask, v32

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -994,7 +994,7 @@ def NVVM_BarrierOp : NVVM_Op<"barrier", [AttrSizedOperandSegments]> {
994994

995995
let assemblyFormat =
996996
"(`id` `=` $barrierId^)? (`number_of_threads` `=` $numberOfThreads^)? "
997-
"($reductionOp^ $reductionPredicate)? (`->` type($res)^)? attr-dict";
997+
"(qualified($reductionOp)^ $reductionPredicate)? (`->` type($res)^)? attr-dict";
998998

999999
let builders = [OpBuilder<(ins), [{
10001000
return build($_builder, $_state, TypeRange{}, Value{}, Value{}, {}, Value{});

mlir/test/Target/LLVMIR/nvvm/barrier.mlir

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,26 @@
1-
// RUN: mlir-translate -mlir-to-llvmir %s -split-input-file --verify-diagnostics | FileCheck %s
1+
// RUN: mlir-translate -mlir-to-llvmir %s -split-input-file --verify-diagnostics | FileCheck %s --check-prefix=LLVM
2+
// RUN: mlir-opt %s -split-input-file | mlir-opt | FileCheck %s
23

3-
// CHECK-LABEL: @llvm_nvvm_barrier(
4-
// CHECK-SAME: i32 %[[barId:.*]], i32 %[[numThreads:.*]], i32 %[[redOperand:.*]])
4+
// LLVM-LABEL: @llvm_nvvm_barrier(
5+
// LLVM-SAME: i32 %[[barId:.*]], i32 %[[numThreads:.*]], i32 %[[redOperand:.*]])
56
llvm.func @llvm_nvvm_barrier(%barID : i32, %numberOfThreads : i32, %redOperand : i32) {
6-
// CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
7+
// LLVM: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
8+
// CHECK: nvvm.barrier
79
nvvm.barrier
8-
// CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 %[[barId]])
10+
// LLVM: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 %[[barId]])
11+
// CHECK: nvvm.barrier id = %{{.*}}
912
nvvm.barrier id = %barID
10-
// CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.count(i32 %[[barId]], i32 %[[numThreads]])
13+
// LLVM: call void @llvm.nvvm.barrier.cta.sync.aligned.count(i32 %[[barId]], i32 %[[numThreads]])
14+
// CHECK: nvvm.barrier id = %{{.*}} number_of_threads = %{{.*}}
1115
nvvm.barrier id = %barID number_of_threads = %numberOfThreads
12-
// CHECK: %{{.*}} = call i32 @llvm.nvvm.barrier0.and(i32 %[[redOperand]])
16+
// LLVM: %{{.*}} = call i32 @llvm.nvvm.barrier0.and(i32 %[[redOperand]])
17+
// CHECK: %{{.*}} = nvvm.barrier #nvvm.reduction<and> %{{.*}} -> i32
1318
%0 = nvvm.barrier #nvvm.reduction<and> %redOperand -> i32
14-
// CHECK: %{{.*}} = call i32 @llvm.nvvm.barrier0.or(i32 %[[redOperand]])
19+
// LLVM: %{{.*}} = call i32 @llvm.nvvm.barrier0.or(i32 %[[redOperand]])
20+
// CHECK: %{{.*}} = nvvm.barrier #nvvm.reduction<or> %{{.*}} -> i32
1521
%1 = nvvm.barrier #nvvm.reduction<or> %redOperand -> i32
16-
// CHECK: %{{.*}} = call i32 @llvm.nvvm.barrier0.popc(i32 %[[redOperand]])
22+
// LLVM: %{{.*}} = call i32 @llvm.nvvm.barrier0.popc(i32 %[[redOperand]])
23+
// CHECK: %{{.*}} = nvvm.barrier #nvvm.reduction<popc> %{{.*}} -> i32
1724
%2 = nvvm.barrier #nvvm.reduction<popc> %redOperand -> i32
1825

1926
llvm.return

0 commit comments

Comments
 (0)