Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions flang/lib/Lower/OpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,8 @@ genAtomicCaptureStatement(Fortran::lower::AbstractConverter &converter,
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

mlir::acc::AtomicReadOp::create(firOpBuilder, loc, fromAddress, toAddress,
mlir::TypeAttr::get(elementType));
mlir::TypeAttr::get(elementType),
/*IfCond=*/mlir::Value{});
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Here and below - should it be ifCond instead of IfCond?

}

/// Used to generate atomic.write operation which is created in existing
Expand All @@ -347,7 +348,8 @@ genAtomicWriteStatement(Fortran::lower::AbstractConverter &converter,
rhsExpr = firOpBuilder.createConvert(loc, varType, rhsExpr);
firOpBuilder.restoreInsertionPoint(insertionPoint);

mlir::acc::AtomicWriteOp::create(firOpBuilder, loc, lhsAddr, rhsExpr);
mlir::acc::AtomicWriteOp::create(firOpBuilder, loc, lhsAddr, rhsExpr,
/*IfCond=*/mlir::Value{});
}

/// Used to generate atomic.update operation which is created in existing
Expand Down Expand Up @@ -463,7 +465,8 @@ static inline void genAtomicUpdateStatement(

mlir::Operation *atomicUpdateOp = nullptr;
atomicUpdateOp =
mlir::acc::AtomicUpdateOp::create(firOpBuilder, currentLocation, lhsAddr);
mlir::acc::AtomicUpdateOp::create(firOpBuilder, currentLocation, lhsAddr,
/*IfCond=*/mlir::Value{});

llvm::SmallVector<mlir::Type> varTys = {varType};
llvm::SmallVector<mlir::Location> locs = {currentLocation};
Expand Down Expand Up @@ -588,7 +591,9 @@ void genAtomicCapture(Fortran::lower::AbstractConverter &converter,
fir::getBase(converter.genExprValue(assign2.lhs, stmtCtx)).getType();

mlir::Operation *atomicCaptureOp = nullptr;
atomicCaptureOp = mlir::acc::AtomicCaptureOp::create(firOpBuilder, loc);
atomicCaptureOp =
mlir::acc::AtomicCaptureOp::create(firOpBuilder, loc,
/*IfCond=*/mlir::Value{});

firOpBuilder.createBlock(&(atomicCaptureOp->getRegion(0)));
mlir::Block &block = atomicCaptureOp->getRegion(0).back();
Expand Down
30 changes: 25 additions & 5 deletions mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2787,10 +2787,16 @@ def AtomicReadOp : OpenACC_Op<"atomic.read", [AtomicReadOpInterface]> {

let arguments = (ins OpenACC_PointerLikeType:$x,
OpenACC_PointerLikeType:$v,
TypeAttr:$element_type);
TypeAttr:$element_type,
Optional<I1>:$ifCond
);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: maybe move it up to line above?

let assemblyFormat = [{
oilist(
`if` `(` $ifCond `)`
)
$v `=` $x
`:` type($v) `,` type($x) `,` $element_type attr-dict
`:` type($v) `,` type($x) `,` $element_type
attr-dict
}];
let hasVerifier = 1;
}
Expand All @@ -2809,8 +2815,12 @@ def AtomicWriteOp : OpenACC_Op<"atomic.write",[AtomicWriteOpInterface]> {
}];

let arguments = (ins OpenACC_PointerLikeType:$x,
AnyType:$expr);
AnyType:$expr,
Optional<I1>:$ifCond);
let assemblyFormat = [{
oilist(
`if` `(` $ifCond `)`
)
$x `=` $expr
`:` type($x) `,` type($expr)
attr-dict
Expand Down Expand Up @@ -2850,10 +2860,15 @@ def AtomicUpdateOp : OpenACC_Op<"atomic.update",

let arguments = (ins Arg<OpenACC_PointerLikeType,
"Address of variable to be updated",
[MemRead, MemWrite]>:$x);
[MemRead, MemWrite]>:$x,
Optional<I1>:$ifCond);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: mismatched indentation

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah! I missed you don't consider ins here for the indent, so I think I got this right now.

let regions = (region SizedRegion<1>:$region);
let assemblyFormat = [{
$x `:` type($x) $region attr-dict
oilist(
`if` `(` $ifCond `)`
)
$x `:` type($x)
$region attr-dict
}];
let hasVerifier = 1;
let hasRegionVerifier = 1;
Expand Down Expand Up @@ -2896,8 +2911,13 @@ def AtomicCaptureOp : OpenACC_Op<"atomic.capture",

}];

let arguments = (ins Optional<I1>:$ifCond);

let regions = (region SizedRegion<1>:$region);
let assemblyFormat = [{
oilist(
`if` `(` $ifCond `)`
)
$region attr-dict
}];
let hasRegionVerifier = 1;
Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3858,7 +3858,8 @@ LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
}

if (Value writeVal = op.getWriteOpVal()) {
rewriter.replaceOpWithNewOp<AtomicWriteOp>(op, op.getX(), writeVal);
rewriter.replaceOpWithNewOp<AtomicWriteOp>(op, op.getX(), writeVal,
op.getIfCond());
return success();
}

Expand Down
36 changes: 36 additions & 0 deletions mlir/test/Dialect/OpenACC/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1766,6 +1766,12 @@ acc.set default_async(%i32Value : i32)
func.func @acc_atomic_read(%v: memref<i32>, %x: memref<i32>) {
// CHECK: acc.atomic.read %[[v]] = %[[x]] : memref<i32>, memref<i32>, i32
acc.atomic.read %v = %x : memref<i32>, memref<i32>, i32

// CHECK-NEXT: %[[IFCOND1:.*]] = arith.constant true
// CHECK-NEXT: acc.atomic.read if(%[[IFCOND1]]) %[[v]] = %[[x]] : memref<i32>, memref<i32>, i32
%ifCond = arith.constant true
acc.atomic.read if(%ifCond) %v = %x : memref<i32>, memref<i32>, i32

return
}

Expand All @@ -1776,6 +1782,12 @@ func.func @acc_atomic_read(%v: memref<i32>, %x: memref<i32>) {
func.func @acc_atomic_write(%addr : memref<i32>, %val : i32) {
// CHECK: acc.atomic.write %[[ADDR]] = %[[VAL]] : memref<i32>, i32
acc.atomic.write %addr = %val : memref<i32>, i32

// CHECK-NEXT: %[[IFCOND1:.*]] = arith.constant true
// CHECK-NEXT: acc.atomic.write if(%[[IFCOND1]]) %[[ADDR]] = %[[VAL]] : memref<i32>, i32
%ifCond = arith.constant true
acc.atomic.write if(%ifCond) %addr = %val : memref<i32>, i32

return
}

Expand All @@ -1793,6 +1805,19 @@ func.func @acc_atomic_update(%x : memref<i32>, %expr : i32, %xBool : memref<i1>,
%newval = llvm.add %xval, %expr : i32
acc.yield %newval : i32
}

// CHECK: %[[IFCOND1:.*]] = arith.constant true
// CHECK-NEXT: acc.atomic.update if(%[[IFCOND1]]) %[[X]] : memref<i32>
// CHECK-NEXT: (%[[XVAL:.*]]: i32):
// CHECK-NEXT: %[[NEWVAL:.*]] = llvm.add %[[XVAL]], %[[EXPR]] : i32
// CHECK-NEXT: acc.yield %[[NEWVAL]] : i32
%ifCond = arith.constant true
acc.atomic.update if (%ifCond) %x : memref<i32> {
^bb0(%xval: i32):
%newval = llvm.add %xval, %expr : i32
acc.yield %newval : i32
}

// CHECK: acc.atomic.update %[[XBOOL]] : memref<i1>
// CHECK-NEXT: (%[[XVAL:.*]]: i1):
// CHECK-NEXT: %[[NEWVAL:.*]] = llvm.and %[[XVAL]], %[[EXPRBOOL]] : i1
Expand Down Expand Up @@ -1902,6 +1927,17 @@ func.func @acc_atomic_capture(%v: memref<i32>, %x: memref<i32>, %expr: i32) {
acc.atomic.write %x = %expr : memref<i32>, i32
}

// CHECK: %[[IFCOND1:.*]] = arith.constant true
// CHECK-NEXT: acc.atomic.capture if(%[[IFCOND1]]) {
// CHECK-NEXT: acc.atomic.read %[[v]] = %[[x]] : memref<i32>, memref<i32>, i32
// CHECK-NEXT: acc.atomic.write %[[x]] = %[[expr]] : memref<i32>, i32
// CHECK-NEXT: }
%ifCond = arith.constant true
acc.atomic.capture if (%ifCond) {
acc.atomic.read %v = %x : memref<i32>, memref<i32>, i32
acc.atomic.write %x = %expr : memref<i32>, i32
}

return
}

Expand Down