Skip to content

Commit 7102e6c

Browse files
committed
[mlir][OpenACC][OpenMP] Modify atomic capture to allow update/write
OpenACC's C++ version has a variant of capture that permits a update followed by a write. Therefore the verifier was overly strict in this case. According to our reading of the OpenMP 6.0 spec, it appears that `atomic captured update` (page 495) also requires this form, so it seems reasonable to allow this for both languages.
1 parent 6429549 commit 7102e6c

File tree

4 files changed

+61
-4
lines changed

4 files changed

+61
-4
lines changed

mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2999,6 +2999,12 @@ def AtomicCaptureOp : OpenACC_Op<"atomic.capture",
29992999
acc.atomic.write ...
30003000
acc.terminator
30013001
}
3002+
3003+
acc.atomic.capture {
3004+
acc.atomic.update ...
3005+
acc.atomic.write ...
3006+
acc.terminator
3007+
}
30023008
```
30033009

30043010
}];

mlir/include/mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.td

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,7 @@ def AtomicCaptureOpInterface : OpInterface<"AtomicCaptureOpInterface"> {
239239
implement one of the atomic interfaces. It can be found in one of these
240240
forms:
241241
`{ atomic.update, atomic.read }`
242+
`{ atomic.update, atomic.write }`
242243
`{ atomic.read, atomic.update }`
243244
`{ atomic.read, atomic.write }`
244245
}];
@@ -291,12 +292,15 @@ def AtomicCaptureOpInterface : OpInterface<"AtomicCaptureOpInterface"> {
291292
auto secondWriteStmt = dyn_cast<AtomicWriteOpInterface>(secondOp);
292293

293294
if (!((firstUpdateStmt && secondReadStmt) ||
295+
(firstUpdateStmt && secondWriteStmt) ||
294296
(firstReadStmt && secondUpdateStmt) ||
295297
(firstReadStmt && secondWriteStmt)))
296298
return ops.front().emitError()
297299
<< "invalid sequence of operations in the capture region";
298-
if (firstUpdateStmt && secondReadStmt &&
299-
firstUpdateStmt.getX() != secondReadStmt.getX())
300+
if ((firstUpdateStmt && secondReadStmt &&
301+
firstUpdateStmt.getX() != secondReadStmt.getX()) ||
302+
(firstUpdateStmt && secondWriteStmt &&
303+
firstUpdateStmt.getX() != secondWriteStmt.getX()))
300304
return firstUpdateStmt.emitError()
301305
<< "updated variable in atomic.update must be captured in "
302306
"second operation";

mlir/test/Dialect/OpenACC/invalid.mlir

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -690,7 +690,6 @@ func.func @acc_atomic_capture(%x: memref<i32>, %v: memref<i32>, %expr: i32) {
690690

691691
func.func @acc_atomic_capture(%x: memref<i32>, %v: memref<i32>, %expr: i32) {
692692
acc.atomic.capture {
693-
// expected-error @below {{invalid sequence of operations in the capture region}}
694693
acc.atomic.update %x : memref<i32> {
695694
^bb0(%xval: i32):
696695
%newval = llvm.add %xval, %expr : i32
@@ -704,6 +703,23 @@ func.func @acc_atomic_capture(%x: memref<i32>, %v: memref<i32>, %expr: i32) {
704703

705704
// -----
706705

706+
func.func @acc_atomic_capture(%x: memref<i32>, %v: memref<i32>, %expr: i32) {
707+
acc.atomic.capture {
708+
// expected-error @below {{updated variable in atomic.update must be captured in second operation}}
709+
acc.atomic.update %x : memref<i32> {
710+
^bb0(%xval: i32):
711+
%newval = llvm.add %xval, %expr : i32
712+
acc.yield %newval : i32
713+
}
714+
acc.atomic.write %v = %expr : memref<i32>, i32
715+
716+
acc.terminator
717+
}
718+
return
719+
}
720+
721+
// -----
722+
707723
func.func @acc_atomic_capture(%x: memref<i32>, %v: memref<i32>, %expr: i32) {
708724
acc.atomic.capture {
709725
// expected-error @below {{invalid sequence of operations in the capture region}}

mlir/test/Dialect/OpenMP/invalid.mlir

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1263,7 +1263,22 @@ func.func @omp_atomic_capture(%x: memref<i32>, %v: memref<i32>, %expr: i32) {
12631263

12641264
func.func @omp_atomic_capture(%x: memref<i32>, %v: memref<i32>, %expr: i32) {
12651265
omp.atomic.capture {
1266-
// expected-error @below {{invalid sequence of operations in the capture region}}
1266+
// expected-error @below {{updated variable in atomic.update must be captured in second operation}}
1267+
omp.atomic.update %x : memref<i32> {
1268+
^bb0(%xval: i32):
1269+
%newval = llvm.add %xval, %expr : i32
1270+
omp.yield (%newval : i32)
1271+
}
1272+
omp.atomic.write %v = %expr : memref<i32>, i32
1273+
omp.terminator
1274+
}
1275+
return
1276+
}
1277+
1278+
// -----
1279+
1280+
func.func @omp_atomic_capture(%x: memref<i32>, %v: memref<i32>, %expr: i32) {
1281+
omp.atomic.capture {
12671282
omp.atomic.update %x : memref<i32> {
12681283
^bb0(%xval: i32):
12691284
%newval = llvm.add %xval, %expr : i32
@@ -1289,6 +1304,22 @@ func.func @omp_atomic_capture(%x: memref<i32>, %v: memref<i32>, %expr: i32) {
12891304

12901305
// -----
12911306

1307+
func.func @omp_atomic_capture(%x: memref<i32>, %v: memref<i32>, %expr: i32) {
1308+
omp.atomic.capture {
1309+
omp.atomic.update %x : memref<i32> {
1310+
^bb0(%xval: i32):
1311+
%newval = llvm.add %xval, %expr : i32
1312+
omp.yield (%newval : i32)
1313+
}
1314+
omp.atomic.write %x = %expr : memref<i32>, i32
1315+
1316+
omp.terminator
1317+
}
1318+
return
1319+
}
1320+
1321+
// -----
1322+
12921323
func.func @omp_atomic_capture(%x: memref<i32>, %y: memref<i32>, %v: memref<i32>, %expr: i32) {
12931324
omp.atomic.capture {
12941325
// expected-error @below {{updated variable in atomic.update must be captured in second operation}}

0 commit comments

Comments
 (0)