Skip to content

Commit 6ca7f87

Browse files
committed
Fixed atomic capture cases with atomic update inside.
1 parent b895e18 commit 6ca7f87

File tree

2 files changed

+55
-9
lines changed

2 files changed

+55
-9
lines changed

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2816,7 +2816,8 @@ static void genAtomicUpdateStatement(
28162816
const parser::Expr &assignmentStmtExpr,
28172817
const parser::OmpAtomicClauseList *leftHandClauseList,
28182818
const parser::OmpAtomicClauseList *rightHandClauseList, mlir::Location loc,
2819-
mlir::Operation *atomicCaptureOp = nullptr) {
2819+
mlir::Operation *atomicCaptureOp = nullptr,
2820+
lower::StatementContext *atomicCaptureStmtCtx = nullptr) {
28202821
// Generate `atomic.update` operation for atomic assignment statements
28212822
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
28222823
mlir::Location currentLocation = converter.getCurrentLocation();
@@ -2890,15 +2891,24 @@ static void genAtomicUpdateStatement(
28902891
},
28912892
assignmentStmtExpr.u);
28922893
lower::StatementContext nonAtomicStmtCtx;
2894+
lower::StatementContext *stmtCtxPtr = &nonAtomicStmtCtx;
28932895
if (!nonAtomicSubExprs.empty()) {
28942896
// Generate non-atomic part before all the atomic operations.
28952897
auto insertionPoint = firOpBuilder.saveInsertionPoint();
2896-
if (atomicCaptureOp)
2898+
if (atomicCaptureOp) {
2899+
assert(atomicCaptureStmtCtx && "must specify statement context");
28972900
firOpBuilder.setInsertionPoint(atomicCaptureOp);
2901+
// Any clean-ups associated with the expression lowering
2902+
// must also be generated outside of the atomic update operation
2903+
// and after the atomic capture operation.
2904+
// The atomicCaptureStmtCtx will be finalized at the end
2905+
// of the atomic capture operation generation.
2906+
stmtCtxPtr = atomicCaptureStmtCtx;
2907+
}
28982908
mlir::Value nonAtomicVal;
28992909
for (auto *nonAtomicSubExpr : nonAtomicSubExprs) {
29002910
nonAtomicVal = fir::getBase(converter.genExprValue(
2901-
currentLocation, *nonAtomicSubExpr, nonAtomicStmtCtx));
2911+
currentLocation, *nonAtomicSubExpr, *stmtCtxPtr));
29022912
exprValueOverrides.try_emplace(nonAtomicSubExpr, nonAtomicVal);
29032913
}
29042914
if (atomicCaptureOp)
@@ -3238,7 +3248,7 @@ static void genAtomicCapture(lower::AbstractConverter &converter,
32383248
genAtomicUpdateStatement(
32393249
converter, stmt2LHSArg, stmt2VarType, stmt2Var, stmt2Expr,
32403250
/*leftHandClauseList=*/nullptr,
3241-
/*rightHandClauseList=*/nullptr, loc, atomicCaptureOp);
3251+
/*rightHandClauseList=*/nullptr, loc, atomicCaptureOp, &stmtCtx);
32423252
} else {
32433253
// Atomic capture construct is of the form [capture-stmt, write-stmt]
32443254
firOpBuilder.setInsertionPoint(atomicCaptureOp);
@@ -3284,7 +3294,7 @@ static void genAtomicCapture(lower::AbstractConverter &converter,
32843294
genAtomicUpdateStatement(
32853295
converter, stmt1LHSArg, stmt1VarType, stmt1Var, stmt1Expr,
32863296
/*leftHandClauseList=*/nullptr,
3287-
/*rightHandClauseList=*/nullptr, loc, atomicCaptureOp);
3297+
/*rightHandClauseList=*/nullptr, loc, atomicCaptureOp, &stmtCtx);
32883298

32893299
if (stmt1VarType != stmt2VarType) {
32903300
mlir::Value alloca;

flang/test/Lower/OpenMP/atomic-capture.f90

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,18 +102,54 @@ subroutine pointers_in_atomic_capture()
102102
! are generated after the omp.atomic.capture operation:
103103
! CHECK-LABEL: func.func @_QPfunc_call_cleanup(
104104
subroutine func_call_cleanup(x, v, vv)
105+
interface
106+
integer function func(x)
107+
integer :: x
108+
end function func
109+
end interface
105110
integer :: x, v, vv
106111

107112
! CHECK: %[[VAL_7:.*]]:3 = hlfir.associate %{{.*}} {adapt.valuebyref} : (i32) -> (!fir.ref<i32>, !fir.ref<i32>, i1)
108-
! CHECK: %[[VAL_8:.*]] = fir.call @_QPfunc(%[[VAL_7]]#0) fastmath<contract> : (!fir.ref<i32>) -> f32
109-
! CHECK: %[[VAL_9:.*]] = fir.convert %[[VAL_8]] : (f32) -> i32
113+
! CHECK: %[[VAL_8:.*]] = fir.call @_QPfunc(%[[VAL_7]]#0) fastmath<contract> : (!fir.ref<i32>) -> i32
110114
! CHECK: omp.atomic.capture {
111-
! CHECK: omp.atomic.read %{{.*}} = %[[VAL_3:.*]]#0 : !fir.ref<i32>, !fir.ref<i32>, i32
112-
! CHECK: omp.atomic.write %[[VAL_3]]#0 = %[[VAL_9]] : !fir.ref<i32>, i32
115+
! CHECK: omp.atomic.read %[[VAL_1:.*]]#0 = %[[VAL_3:.*]]#0 : !fir.ref<i32>, !fir.ref<i32>, i32
116+
! CHECK: omp.atomic.write %[[VAL_3]]#0 = %[[VAL_8]] : !fir.ref<i32>, i32
113117
! CHECK: }
114118
! CHECK: hlfir.end_associate %[[VAL_7]]#1, %[[VAL_7]]#2 : !fir.ref<i32>, i1
115119
!$omp atomic capture
116120
v = x
117121
x = func(vv + 1)
118122
!$omp end atomic
123+
124+
! CHECK: %[[VAL_12:.*]]:3 = hlfir.associate %{{.*}} {adapt.valuebyref} : (i32) -> (!fir.ref<i32>, !fir.ref<i32>, i1)
125+
! CHECK: %[[VAL_13:.*]] = fir.call @_QPfunc(%[[VAL_12]]#0) fastmath<contract> : (!fir.ref<i32>) -> i32
126+
! CHECK: omp.atomic.capture {
127+
! CHECK: omp.atomic.read %[[VAL_1]]#0 = %[[VAL_3]]#0 : !fir.ref<i32>, !fir.ref<i32>, i32
128+
! CHECK: omp.atomic.update %[[VAL_3]]#0 : !fir.ref<i32> {
129+
! CHECK: ^bb0(%[[VAL_14:.*]]: i32):
130+
! CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_13]], %[[VAL_14]] : i32
131+
! CHECK: omp.yield(%[[VAL_15]] : i32)
132+
! CHECK: }
133+
! CHECK: }
134+
! CHECK: hlfir.end_associate %[[VAL_12]]#1, %[[VAL_12]]#2 : !fir.ref<i32>, i1
135+
!$omp atomic capture
136+
v = x
137+
x = func(vv + 1) + x
138+
!$omp end atomic
139+
140+
! CHECK: %[[VAL_19:.*]]:3 = hlfir.associate %{{.*}} {adapt.valuebyref} : (i32) -> (!fir.ref<i32>, !fir.ref<i32>, i1)
141+
! CHECK: %[[VAL_20:.*]] = fir.call @_QPfunc(%[[VAL_19]]#0) fastmath<contract> : (!fir.ref<i32>) -> i32
142+
! CHECK: omp.atomic.capture {
143+
! CHECK: omp.atomic.update %[[VAL_3]]#0 : !fir.ref<i32> {
144+
! CHECK: ^bb0(%[[VAL_21:.*]]: i32):
145+
! CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_20]], %[[VAL_21]] : i32
146+
! CHECK: omp.yield(%[[VAL_22]] : i32)
147+
! CHECK: }
148+
! CHECK: omp.atomic.read %[[VAL_1]]#0 = %[[VAL_3]]#0 : !fir.ref<i32>, !fir.ref<i32>, i32
149+
! CHECK: }
150+
! CHECK: hlfir.end_associate %[[VAL_19]]#1, %[[VAL_19]]#2 : !fir.ref<i32>, i1
151+
!$omp atomic capture
152+
x = func(vv + 1) + x
153+
v = x
154+
!$omp end atomic
119155
end subroutine func_call_cleanup

0 commit comments

Comments
 (0)