diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp index cfb18914e8126..77a5716ab9c04 100644 --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -327,7 +327,8 @@ genAtomicCaptureStatement(Fortran::lower::AbstractConverter &converter, fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); mlir::acc::AtomicReadOp::create(firOpBuilder, loc, fromAddress, toAddress, - mlir::TypeAttr::get(elementType)); + mlir::TypeAttr::get(elementType), + /*ifCond=*/mlir::Value{}); } /// Used to generate atomic.write operation which is created in existing @@ -347,7 +348,8 @@ genAtomicWriteStatement(Fortran::lower::AbstractConverter &converter, rhsExpr = firOpBuilder.createConvert(loc, varType, rhsExpr); firOpBuilder.restoreInsertionPoint(insertionPoint); - mlir::acc::AtomicWriteOp::create(firOpBuilder, loc, lhsAddr, rhsExpr); + mlir::acc::AtomicWriteOp::create(firOpBuilder, loc, lhsAddr, rhsExpr, + /*ifCond=*/mlir::Value{}); } /// Used to generate atomic.update operation which is created in existing @@ -463,7 +465,8 @@ static inline void genAtomicUpdateStatement( mlir::Operation *atomicUpdateOp = nullptr; atomicUpdateOp = - mlir::acc::AtomicUpdateOp::create(firOpBuilder, currentLocation, lhsAddr); + mlir::acc::AtomicUpdateOp::create(firOpBuilder, currentLocation, lhsAddr, + /*ifCond=*/mlir::Value{}); llvm::SmallVector varTys = {varType}; llvm::SmallVector locs = {currentLocation}; @@ -588,7 +591,9 @@ void genAtomicCapture(Fortran::lower::AbstractConverter &converter, fir::getBase(converter.genExprValue(assign2.lhs, stmtCtx)).getType(); mlir::Operation *atomicCaptureOp = nullptr; - atomicCaptureOp = mlir::acc::AtomicCaptureOp::create(firOpBuilder, loc); + atomicCaptureOp = + mlir::acc::AtomicCaptureOp::create(firOpBuilder, loc, + /*ifCond=*/mlir::Value{}); firOpBuilder.createBlock(&(atomicCaptureOp->getRegion(0))); mlir::Block &block = atomicCaptureOp->getRegion(0).back(); diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td index 1eaa21b46554c..e78cdbe3d3e64 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td @@ -2787,10 +2787,14 @@ def AtomicReadOp : OpenACC_Op<"atomic.read", [AtomicReadOpInterface]> { let arguments = (ins OpenACC_PointerLikeType:$x, OpenACC_PointerLikeType:$v, - TypeAttr:$element_type); + TypeAttr:$element_type, Optional:$ifCond); let assemblyFormat = [{ + oilist( + `if` `(` $ifCond `)` + ) $v `=` $x - `:` type($v) `,` type($x) `,` $element_type attr-dict + `:` type($v) `,` type($x) `,` $element_type + attr-dict }]; let hasVerifier = 1; } @@ -2809,8 +2813,12 @@ def AtomicWriteOp : OpenACC_Op<"atomic.write",[AtomicWriteOpInterface]> { }]; let arguments = (ins OpenACC_PointerLikeType:$x, - AnyType:$expr); + AnyType:$expr, + Optional:$ifCond); let assemblyFormat = [{ + oilist( + `if` `(` $ifCond `)` + ) $x `=` $expr `:` type($x) `,` type($expr) attr-dict @@ -2850,10 +2858,15 @@ def AtomicUpdateOp : OpenACC_Op<"atomic.update", let arguments = (ins Arg:$x); + [MemRead, MemWrite]>:$x, + Optional:$ifCond); let regions = (region SizedRegion<1>:$region); let assemblyFormat = [{ - $x `:` type($x) $region attr-dict + oilist( + `if` `(` $ifCond `)` + ) + $x `:` type($x) + $region attr-dict }]; let hasVerifier = 1; let hasRegionVerifier = 1; @@ -2896,8 +2909,13 @@ def AtomicCaptureOp : OpenACC_Op<"atomic.capture", }]; + let arguments = (ins Optional:$ifCond); + let regions = (region SizedRegion<1>:$region); let assemblyFormat = [{ + oilist( + `if` `(` $ifCond `)` + ) $region attr-dict }]; let hasRegionVerifier = 1; diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index dcfe2c742407e..c7811fb02f2c7 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -3842,7 +3842,8 @@ LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op, } if (Value writeVal = op.getWriteOpVal()) { - rewriter.replaceOpWithNewOp(op, op.getX(), writeVal); + rewriter.replaceOpWithNewOp(op, op.getX(), writeVal, + op.getIfCond()); return success(); } diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir index 1484d7efd87c2..8713689ed5799 100644 --- a/mlir/test/Dialect/OpenACC/ops.mlir +++ b/mlir/test/Dialect/OpenACC/ops.mlir @@ -1766,6 +1766,12 @@ acc.set default_async(%i32Value : i32) func.func @acc_atomic_read(%v: memref, %x: memref) { // CHECK: acc.atomic.read %[[v]] = %[[x]] : memref, memref, i32 acc.atomic.read %v = %x : memref, memref, i32 + + // CHECK-NEXT: %[[IFCOND1:.*]] = arith.constant true + // CHECK-NEXT: acc.atomic.read if(%[[IFCOND1]]) %[[v]] = %[[x]] : memref, memref, i32 + %ifCond = arith.constant true + acc.atomic.read if(%ifCond) %v = %x : memref, memref, i32 + return } @@ -1776,6 +1782,12 @@ func.func @acc_atomic_read(%v: memref, %x: memref) { func.func @acc_atomic_write(%addr : memref, %val : i32) { // CHECK: acc.atomic.write %[[ADDR]] = %[[VAL]] : memref, i32 acc.atomic.write %addr = %val : memref, i32 + + // CHECK-NEXT: %[[IFCOND1:.*]] = arith.constant true + // CHECK-NEXT: acc.atomic.write if(%[[IFCOND1]]) %[[ADDR]] = %[[VAL]] : memref, i32 + %ifCond = arith.constant true + acc.atomic.write if(%ifCond) %addr = %val : memref, i32 + return } @@ -1793,6 +1805,19 @@ func.func @acc_atomic_update(%x : memref, %expr : i32, %xBool : memref, %newval = llvm.add %xval, %expr : i32 acc.yield %newval : i32 } + + // CHECK: %[[IFCOND1:.*]] = arith.constant true + // CHECK-NEXT: acc.atomic.update if(%[[IFCOND1]]) %[[X]] : memref + // CHECK-NEXT: (%[[XVAL:.*]]: i32): + // CHECK-NEXT: %[[NEWVAL:.*]] = llvm.add %[[XVAL]], %[[EXPR]] : i32 + // CHECK-NEXT: acc.yield %[[NEWVAL]] : i32 + %ifCond = arith.constant true + acc.atomic.update if (%ifCond) %x : memref { + ^bb0(%xval: i32): + %newval = llvm.add %xval, %expr : i32 + acc.yield %newval : i32 + } + // CHECK: acc.atomic.update %[[XBOOL]] : memref // CHECK-NEXT: (%[[XVAL:.*]]: i1): // CHECK-NEXT: %[[NEWVAL:.*]] = llvm.and %[[XVAL]], %[[EXPRBOOL]] : i1 @@ -1902,6 +1927,17 @@ func.func @acc_atomic_capture(%v: memref, %x: memref, %expr: i32) { acc.atomic.write %x = %expr : memref, i32 } + // CHECK: %[[IFCOND1:.*]] = arith.constant true + // CHECK-NEXT: acc.atomic.capture if(%[[IFCOND1]]) { + // CHECK-NEXT: acc.atomic.read %[[v]] = %[[x]] : memref, memref, i32 + // CHECK-NEXT: acc.atomic.write %[[x]] = %[[expr]] : memref, i32 + // CHECK-NEXT: } + %ifCond = arith.constant true + acc.atomic.capture if (%ifCond) { + acc.atomic.read %v = %x : memref, memref, i32 + acc.atomic.write %x = %expr : memref, i32 + } + return }