Skip to content

Commit 941d2fd

Browse files
authored
[acc][mlir] Add 'if-condition' to 'atomic' operations. (llvm#164003)
OpenACC 3.4 includes the ability to add an 'if' to an atomic operation. From the change log: `Added the if clause to the atomic construct to enable conditional atomic operations based867 on the parallelism strategy employed` In 2.12, the C/C++ grammar is changed to say: `#pragma acc atomic [ atomic-clause ] [ if( condition ) ] new-line` With corresponding changes to the Fortran standard This patch adds support to this for the dialect, so that Clang can use it soon.
1 parent 0731f18 commit 941d2fd

File tree

4 files changed

+70
-10
lines changed

4 files changed

+70
-10
lines changed

flang/lib/Lower/OpenACC.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,8 @@ genAtomicCaptureStatement(Fortran::lower::AbstractConverter &converter,
327327
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
328328

329329
mlir::acc::AtomicReadOp::create(firOpBuilder, loc, fromAddress, toAddress,
330-
mlir::TypeAttr::get(elementType));
330+
mlir::TypeAttr::get(elementType),
331+
/*ifCond=*/mlir::Value{});
331332
}
332333

333334
/// Used to generate atomic.write operation which is created in existing
@@ -347,7 +348,8 @@ genAtomicWriteStatement(Fortran::lower::AbstractConverter &converter,
347348
rhsExpr = firOpBuilder.createConvert(loc, varType, rhsExpr);
348349
firOpBuilder.restoreInsertionPoint(insertionPoint);
349350

350-
mlir::acc::AtomicWriteOp::create(firOpBuilder, loc, lhsAddr, rhsExpr);
351+
mlir::acc::AtomicWriteOp::create(firOpBuilder, loc, lhsAddr, rhsExpr,
352+
/*ifCond=*/mlir::Value{});
351353
}
352354

353355
/// Used to generate atomic.update operation which is created in existing
@@ -463,7 +465,8 @@ static inline void genAtomicUpdateStatement(
463465

464466
mlir::Operation *atomicUpdateOp = nullptr;
465467
atomicUpdateOp =
466-
mlir::acc::AtomicUpdateOp::create(firOpBuilder, currentLocation, lhsAddr);
468+
mlir::acc::AtomicUpdateOp::create(firOpBuilder, currentLocation, lhsAddr,
469+
/*ifCond=*/mlir::Value{});
467470

468471
llvm::SmallVector<mlir::Type> varTys = {varType};
469472
llvm::SmallVector<mlir::Location> locs = {currentLocation};
@@ -588,7 +591,9 @@ void genAtomicCapture(Fortran::lower::AbstractConverter &converter,
588591
fir::getBase(converter.genExprValue(assign2.lhs, stmtCtx)).getType();
589592

590593
mlir::Operation *atomicCaptureOp = nullptr;
591-
atomicCaptureOp = mlir::acc::AtomicCaptureOp::create(firOpBuilder, loc);
594+
atomicCaptureOp =
595+
mlir::acc::AtomicCaptureOp::create(firOpBuilder, loc,
596+
/*ifCond=*/mlir::Value{});
592597

593598
firOpBuilder.createBlock(&(atomicCaptureOp->getRegion(0)));
594599
mlir::Block &block = atomicCaptureOp->getRegion(0).back();

mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2787,10 +2787,14 @@ def AtomicReadOp : OpenACC_Op<"atomic.read", [AtomicReadOpInterface]> {
27872787

27882788
let arguments = (ins OpenACC_PointerLikeType:$x,
27892789
OpenACC_PointerLikeType:$v,
2790-
TypeAttr:$element_type);
2790+
TypeAttr:$element_type, Optional<I1>:$ifCond);
27912791
let assemblyFormat = [{
2792+
oilist(
2793+
`if` `(` $ifCond `)`
2794+
)
27922795
$v `=` $x
2793-
`:` type($v) `,` type($x) `,` $element_type attr-dict
2796+
`:` type($v) `,` type($x) `,` $element_type
2797+
attr-dict
27942798
}];
27952799
let hasVerifier = 1;
27962800
}
@@ -2809,8 +2813,12 @@ def AtomicWriteOp : OpenACC_Op<"atomic.write",[AtomicWriteOpInterface]> {
28092813
}];
28102814

28112815
let arguments = (ins OpenACC_PointerLikeType:$x,
2812-
AnyType:$expr);
2816+
AnyType:$expr,
2817+
Optional<I1>:$ifCond);
28132818
let assemblyFormat = [{
2819+
oilist(
2820+
`if` `(` $ifCond `)`
2821+
)
28142822
$x `=` $expr
28152823
`:` type($x) `,` type($expr)
28162824
attr-dict
@@ -2850,10 +2858,15 @@ def AtomicUpdateOp : OpenACC_Op<"atomic.update",
28502858

28512859
let arguments = (ins Arg<OpenACC_PointerLikeType,
28522860
"Address of variable to be updated",
2853-
[MemRead, MemWrite]>:$x);
2861+
[MemRead, MemWrite]>:$x,
2862+
Optional<I1>:$ifCond);
28542863
let regions = (region SizedRegion<1>:$region);
28552864
let assemblyFormat = [{
2856-
$x `:` type($x) $region attr-dict
2865+
oilist(
2866+
`if` `(` $ifCond `)`
2867+
)
2868+
$x `:` type($x)
2869+
$region attr-dict
28572870
}];
28582871
let hasVerifier = 1;
28592872
let hasRegionVerifier = 1;
@@ -2896,8 +2909,13 @@ def AtomicCaptureOp : OpenACC_Op<"atomic.capture",
28962909

28972910
}];
28982911

2912+
let arguments = (ins Optional<I1>:$ifCond);
2913+
28992914
let regions = (region SizedRegion<1>:$region);
29002915
let assemblyFormat = [{
2916+
oilist(
2917+
`if` `(` $ifCond `)`
2918+
)
29012919
$region attr-dict
29022920
}];
29032921
let hasRegionVerifier = 1;

mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3842,7 +3842,8 @@ LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
38423842
}
38433843

38443844
if (Value writeVal = op.getWriteOpVal()) {
3845-
rewriter.replaceOpWithNewOp<AtomicWriteOp>(op, op.getX(), writeVal);
3845+
rewriter.replaceOpWithNewOp<AtomicWriteOp>(op, op.getX(), writeVal,
3846+
op.getIfCond());
38463847
return success();
38473848
}
38483849

mlir/test/Dialect/OpenACC/ops.mlir

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1766,6 +1766,12 @@ acc.set default_async(%i32Value : i32)
17661766
func.func @acc_atomic_read(%v: memref<i32>, %x: memref<i32>) {
17671767
// CHECK: acc.atomic.read %[[v]] = %[[x]] : memref<i32>, memref<i32>, i32
17681768
acc.atomic.read %v = %x : memref<i32>, memref<i32>, i32
1769+
1770+
// CHECK-NEXT: %[[IFCOND1:.*]] = arith.constant true
1771+
// CHECK-NEXT: acc.atomic.read if(%[[IFCOND1]]) %[[v]] = %[[x]] : memref<i32>, memref<i32>, i32
1772+
%ifCond = arith.constant true
1773+
acc.atomic.read if(%ifCond) %v = %x : memref<i32>, memref<i32>, i32
1774+
17691775
return
17701776
}
17711777

@@ -1776,6 +1782,12 @@ func.func @acc_atomic_read(%v: memref<i32>, %x: memref<i32>) {
17761782
func.func @acc_atomic_write(%addr : memref<i32>, %val : i32) {
17771783
// CHECK: acc.atomic.write %[[ADDR]] = %[[VAL]] : memref<i32>, i32
17781784
acc.atomic.write %addr = %val : memref<i32>, i32
1785+
1786+
// CHECK-NEXT: %[[IFCOND1:.*]] = arith.constant true
1787+
// CHECK-NEXT: acc.atomic.write if(%[[IFCOND1]]) %[[ADDR]] = %[[VAL]] : memref<i32>, i32
1788+
%ifCond = arith.constant true
1789+
acc.atomic.write if(%ifCond) %addr = %val : memref<i32>, i32
1790+
17791791
return
17801792
}
17811793

@@ -1793,6 +1805,19 @@ func.func @acc_atomic_update(%x : memref<i32>, %expr : i32, %xBool : memref<i1>,
17931805
%newval = llvm.add %xval, %expr : i32
17941806
acc.yield %newval : i32
17951807
}
1808+
1809+
// CHECK: %[[IFCOND1:.*]] = arith.constant true
1810+
// CHECK-NEXT: acc.atomic.update if(%[[IFCOND1]]) %[[X]] : memref<i32>
1811+
// CHECK-NEXT: (%[[XVAL:.*]]: i32):
1812+
// CHECK-NEXT: %[[NEWVAL:.*]] = llvm.add %[[XVAL]], %[[EXPR]] : i32
1813+
// CHECK-NEXT: acc.yield %[[NEWVAL]] : i32
1814+
%ifCond = arith.constant true
1815+
acc.atomic.update if (%ifCond) %x : memref<i32> {
1816+
^bb0(%xval: i32):
1817+
%newval = llvm.add %xval, %expr : i32
1818+
acc.yield %newval : i32
1819+
}
1820+
17961821
// CHECK: acc.atomic.update %[[XBOOL]] : memref<i1>
17971822
// CHECK-NEXT: (%[[XVAL:.*]]: i1):
17981823
// CHECK-NEXT: %[[NEWVAL:.*]] = llvm.and %[[XVAL]], %[[EXPRBOOL]] : i1
@@ -1902,6 +1927,17 @@ func.func @acc_atomic_capture(%v: memref<i32>, %x: memref<i32>, %expr: i32) {
19021927
acc.atomic.write %x = %expr : memref<i32>, i32
19031928
}
19041929

1930+
// CHECK: %[[IFCOND1:.*]] = arith.constant true
1931+
// CHECK-NEXT: acc.atomic.capture if(%[[IFCOND1]]) {
1932+
// CHECK-NEXT: acc.atomic.read %[[v]] = %[[x]] : memref<i32>, memref<i32>, i32
1933+
// CHECK-NEXT: acc.atomic.write %[[x]] = %[[expr]] : memref<i32>, i32
1934+
// CHECK-NEXT: }
1935+
%ifCond = arith.constant true
1936+
acc.atomic.capture if (%ifCond) {
1937+
acc.atomic.read %v = %x : memref<i32>, memref<i32>, i32
1938+
acc.atomic.write %x = %expr : memref<i32>, i32
1939+
}
1940+
19051941
return
19061942
}
19071943

0 commit comments

Comments
 (0)