Skip to content

Commit a509073

Browse files
committed
[flang][acc] Generate acc.copyout for the reduction clause on compute constructs
1 parent 9cb7545 commit a509073

File tree

3 files changed

+45
-3
lines changed

3 files changed

+45
-3
lines changed

flang/lib/Lower/OpenACC.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2676,7 +2676,8 @@ static Op createComputeOp(
26762676
llvm::SmallVector<mlir::Value> waitOperands, attachEntryOperands,
26772677
copyEntryOperands, copyinEntryOperands, copyoutEntryOperands,
26782678
createEntryOperands, nocreateEntryOperands, presentEntryOperands,
2679-
dataClauseOperands, numGangs, numWorkers, vectorLength, async;
2679+
reductionEntryOperands, dataClauseOperands, numGangs, numWorkers,
2680+
vectorLength, async;
26802681
llvm::SmallVector<mlir::Attribute> numGangsDeviceTypes, numWorkersDeviceTypes,
26812682
vectorLengthDeviceTypes, asyncDeviceTypes, asyncOnlyDeviceTypes,
26822683
waitOperandsDeviceTypes, waitOnlyDeviceTypes;
@@ -2912,9 +2913,12 @@ static Op createComputeOp(
29122913
// combined construct implies a copy clause so issue an implicit copy
29132914
// instead.
29142915
if (!combinedConstructs) {
2916+
auto crtDataStart = reductionOperands.size();
29152917
genReductions(reductionClause->v, converter, semanticsContext, stmtCtx,
29162918
reductionOperands, reductionRecipes, async,
29172919
asyncDeviceTypes, asyncOnlyDeviceTypes);
2920+
reductionEntryOperands.append(reductionOperands.begin() + crtDataStart,
2921+
reductionOperands.end());
29182922
} else {
29192923
auto crtDataStart = dataClauseOperands.size();
29202924
genDataOperandOperations<mlir::acc::CopyinOp>(
@@ -3038,6 +3042,8 @@ static Op createComputeOp(
30383042
builder, nocreateEntryOperands, /*structured=*/true);
30393043
genDataExitOperations<mlir::acc::PresentOp, mlir::acc::DeleteOp>(
30403044
builder, presentEntryOperands, /*structured=*/true);
3045+
genDataExitOperations<mlir::acc::ReductionOp, mlir::acc::CopyoutOp>(
3046+
builder, reductionEntryOperands, /*structured=*/true);
30413047

30423048
builder.restoreInsertionPoint(insPt);
30433049
return computeOp;

flang/test/Lower/OpenACC/acc-reduction-unwrap-defaultbounds.f90

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1001,6 +1001,7 @@ subroutine acc_reduction_iand()
10011001
! CHECK-LABEL: func.func @_QPacc_reduction_iand()
10021002
! CHECK: %[[RED:.*]] = acc.reduction varPtr(%{{.*}} : !fir.ref<i32>) -> !fir.ref<i32> {name = "i"}
10031003
! CHECK: acc.parallel reduction(@reduction_iand_ref_i32 -> %[[RED]] : !fir.ref<i32>)
1004+
! CHECK: acc.copyout accPtr(%[[RED]] : !fir.ref<i32>) to varPtr(%{{.*}} : !fir.ref<i32>) {dataClause = #acc<data_clause acc_reduction>, name = "i"}
10041005

10051006
subroutine acc_reduction_ior()
10061007
integer :: i
@@ -1011,6 +1012,7 @@ subroutine acc_reduction_ior()
10111012
! CHECK-LABEL: func.func @_QPacc_reduction_ior()
10121013
! CHECK: %[[RED:.*]] = acc.reduction varPtr(%{{.*}} : !fir.ref<i32>) -> !fir.ref<i32> {name = "i"}
10131014
! CHECK: acc.parallel reduction(@reduction_ior_ref_i32 -> %[[RED]] : !fir.ref<i32>)
1015+
! CHECK: acc.copyout accPtr(%[[RED]] : !fir.ref<i32>) to varPtr(%{{.*}} : !fir.ref<i32>) {dataClause = #acc<data_clause acc_reduction>, name = "i"}
10141016

10151017
subroutine acc_reduction_ieor()
10161018
integer :: i
@@ -1021,6 +1023,7 @@ subroutine acc_reduction_ieor()
10211023
! CHECK-LABEL: func.func @_QPacc_reduction_ieor()
10221024
! CHECK: %[[RED:.*]] = acc.reduction varPtr(%{{.*}} : !fir.ref<i32>) -> !fir.ref<i32> {name = "i"}
10231025
! CHECK: acc.parallel reduction(@reduction_xor_ref_i32 -> %[[RED]] : !fir.ref<i32>)
1026+
! CHECK: acc.copyout accPtr(%[[RED]] : !fir.ref<i32>) to varPtr(%{{.*}} : !fir.ref<i32>) {dataClause = #acc<data_clause acc_reduction>, name = "i"}
10241027

10251028
subroutine acc_reduction_and()
10261029
logical :: l
@@ -1033,6 +1036,7 @@ subroutine acc_reduction_and()
10331036
! CHECK: %[[DECLL:.*]]:2 = hlfir.declare %[[L]]
10341037
! CHECK: %[[RED:.*]] = acc.reduction varPtr(%[[DECLL]]#0 : !fir.ref<!fir.logical<4>>) -> !fir.ref<!fir.logical<4>> {name = "l"}
10351038
! CHECK: acc.parallel reduction(@reduction_land_ref_l32 -> %[[RED]] : !fir.ref<!fir.logical<4>>)
1039+
! CHECK: acc.copyout accPtr(%[[RED]] : !fir.ref<!fir.logical<4>>) to varPtr(%{{.*}} : !fir.ref<!fir.logical<4>>) {dataClause = #acc<data_clause acc_reduction>, name = "l"}
10361040

10371041
subroutine acc_reduction_or()
10381042
logical :: l
@@ -1043,6 +1047,7 @@ subroutine acc_reduction_or()
10431047
! CHECK-LABEL: func.func @_QPacc_reduction_or()
10441048
! CHECK: %[[RED:.*]] = acc.reduction varPtr(%{{.*}} : !fir.ref<!fir.logical<4>>) -> !fir.ref<!fir.logical<4>> {name = "l"}
10451049
! CHECK: acc.parallel reduction(@reduction_lor_ref_l32 -> %[[RED]] : !fir.ref<!fir.logical<4>>)
1050+
! CHECK: acc.copyout accPtr(%[[RED]] : !fir.ref<!fir.logical<4>>) to varPtr(%{{.*}} : !fir.ref<!fir.logical<4>>) {dataClause = #acc<data_clause acc_reduction>, name = "l"}
10461051

10471052
subroutine acc_reduction_eqv()
10481053
logical :: l
@@ -1053,6 +1058,7 @@ subroutine acc_reduction_eqv()
10531058
! CHECK-LABEL: func.func @_QPacc_reduction_eqv()
10541059
! CHECK: %[[RED:.*]] = acc.reduction varPtr(%{{.*}} : !fir.ref<!fir.logical<4>>) -> !fir.ref<!fir.logical<4>> {name = "l"}
10551060
! CHECK: acc.parallel reduction(@reduction_eqv_ref_l32 -> %[[RED]] : !fir.ref<!fir.logical<4>>)
1061+
! CHECK: acc.copyout accPtr(%[[RED]] : !fir.ref<!fir.logical<4>>) to varPtr(%{{.*}} : !fir.ref<!fir.logical<4>>) {dataClause = #acc<data_clause acc_reduction>, name = "l"}
10561062

10571063
subroutine acc_reduction_neqv()
10581064
logical :: l
@@ -1063,6 +1069,7 @@ subroutine acc_reduction_neqv()
10631069
! CHECK-LABEL: func.func @_QPacc_reduction_neqv()
10641070
! CHECK: %[[RED:.*]] = acc.reduction varPtr(%{{.*}} : !fir.ref<!fir.logical<4>>) -> !fir.ref<!fir.logical<4>> {name = "l"}
10651071
! CHECK: acc.parallel reduction(@reduction_neqv_ref_l32 -> %[[RED]] : !fir.ref<!fir.logical<4>>)
1072+
! CHECK: acc.copyout accPtr(%[[RED]] : !fir.ref<!fir.logical<4>>) to varPtr(%{{.*}} : !fir.ref<!fir.logical<4>>) {dataClause = #acc<data_clause acc_reduction>, name = "l"}
10661073

10671074
subroutine acc_reduction_add_cmplx()
10681075
complex :: c
@@ -1073,6 +1080,7 @@ subroutine acc_reduction_add_cmplx()
10731080
! CHECK-LABEL: func.func @_QPacc_reduction_add_cmplx()
10741081
! CHECK: %[[RED:.*]] = acc.reduction varPtr(%{{.*}} : !fir.ref<complex<f32>>) -> !fir.ref<complex<f32>> {name = "c"}
10751082
! CHECK: acc.parallel reduction(@reduction_add_ref_z32 -> %[[RED]] : !fir.ref<complex<f32>>)
1083+
! CHECK: acc.copyout accPtr(%[[RED]] : !fir.ref<complex<f32>>) to varPtr(%{{.*}} : !fir.ref<complex<f32>>) {dataClause = #acc<data_clause acc_reduction>, name = "c"}
10761084

10771085
subroutine acc_reduction_mul_cmplx()
10781086
complex :: c
@@ -1083,6 +1091,7 @@ subroutine acc_reduction_mul_cmplx()
10831091
! CHECK-LABEL: func.func @_QPacc_reduction_mul_cmplx()
10841092
! CHECK: %[[RED:.*]] = acc.reduction varPtr(%{{.*}} : !fir.ref<complex<f32>>) -> !fir.ref<complex<f32>> {name = "c"}
10851093
! CHECK: acc.parallel reduction(@reduction_mul_ref_z32 -> %[[RED]] : !fir.ref<complex<f32>>)
1094+
! CHECK: acc.copyout accPtr(%[[RED]] : !fir.ref<complex<f32>>) to varPtr(%{{.*}} : !fir.ref<complex<f32>>) {dataClause = #acc<data_clause acc_reduction>, name = "c"}
10861095

10871096
subroutine acc_reduction_add_alloc()
10881097
integer, allocatable :: i
@@ -1098,6 +1107,7 @@ subroutine acc_reduction_add_alloc()
10981107
! CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[LOAD]] : (!fir.box<!fir.heap<i32>>) -> !fir.heap<i32>
10991108
! CHECK: %[[RED:.*]] = acc.reduction varPtr(%[[BOX_ADDR]] : !fir.heap<i32>) -> !fir.heap<i32> {name = "i"}
11001109
! CHECK: acc.parallel reduction(@reduction_add_heap_i32 -> %[[RED]] : !fir.heap<i32>)
1110+
! CHECK: acc.copyout accPtr(%[[RED]] : !fir.heap<i32>) to varPtr(%[[BOX_ADDR]] : !fir.heap<i32>) {dataClause = #acc<data_clause acc_reduction>, name = "i"}
11011111

11021112
subroutine acc_reduction_add_pointer(i)
11031113
integer, pointer :: i
@@ -1112,6 +1122,7 @@ subroutine acc_reduction_add_pointer(i)
11121122
! CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[LOAD]] : (!fir.box<!fir.ptr<i32>>) -> !fir.ptr<i32>
11131123
! CHECK: %[[RED:.*]] = acc.reduction varPtr(%[[BOX_ADDR]] : !fir.ptr<i32>) -> !fir.ptr<i32> {name = "i"}
11141124
! CHECK: acc.parallel reduction(@reduction_add_ptr_i32 -> %[[RED]] : !fir.ptr<i32>)
1125+
! CHECK: acc.copyout accPtr(%[[RED]] : !fir.ptr<i32>) to varPtr(%[[BOX_ADDR]] : !fir.ptr<i32>) {dataClause = #acc<data_clause acc_reduction>, name = "i"}
11151126

11161127
subroutine acc_reduction_add_static_slice(a)
11171128
integer :: a(100)
@@ -1129,6 +1140,7 @@ subroutine acc_reduction_add_static_slice(a)
11291140
! CHECK: %[[BOUND:.*]] = acc.bounds lowerbound(%[[LB]] : index) upperbound(%[[UB]] : index) extent(%[[C100]] : index) stride(%[[C1]] : index) startIdx(%[[C1]] : index)
11301141
! CHECK: %[[RED:.*]] = acc.reduction varPtr(%[[DECLARG0]]#0 : !fir.ref<!fir.array<100xi32>>) bounds(%[[BOUND]]) -> !fir.ref<!fir.array<100xi32>> {name = "a(11:20)"}
11311142
! CHECK: acc.parallel reduction(@reduction_add_section_lb10.ub19_ref_100xi32 -> %[[RED]] : !fir.ref<!fir.array<100xi32>>)
1143+
! CHECK: acc.copyout accPtr(%[[RED]] : !fir.ref<!fir.array<100xi32>>) bounds(%[[BOUND]]) to varPtr(%[[DECLARG0]]#0 : !fir.ref<!fir.array<100xi32>>) {dataClause = #acc<data_clause acc_reduction>, name = "a(11:20)"}
11321144

11331145
subroutine acc_reduction_add_dynamic_extent_add(a)
11341146
integer :: a(:)
@@ -1141,6 +1153,7 @@ subroutine acc_reduction_add_dynamic_extent_add(a)
11411153
! CHECK: %[[DECLARG0:.*]]:2 = hlfir.declare %[[ARG0]]
11421154
! CHECK: %[[RED:.*]] = acc.reduction varPtr(%{{.*}} : !fir.ref<!fir.array<?xi32>>) bounds(%{{.*}}) -> !fir.ref<!fir.array<?xi32>> {name = "a"}
11431155
! CHECK: acc.parallel reduction(@reduction_add_box_Uxi32 -> %[[RED]] : !fir.ref<!fir.array<?xi32>>)
1156+
! CHECK: acc.copyout accPtr(%[[RED]] : !fir.ref<!fir.array<?xi32>>) bounds(%{{.*}}) to varPtr(%{{.*}} : !fir.ref<!fir.array<?xi32>>) {dataClause = #acc<data_clause acc_reduction>, name = "a"}
11441157

11451158
subroutine acc_reduction_add_assumed_shape_max(a)
11461159
real :: a(:)
@@ -1153,6 +1166,7 @@ subroutine acc_reduction_add_assumed_shape_max(a)
11531166
! CHECK: %[[DECLARG0:.*]]:2 = hlfir.declare %[[ARG0]]
11541167
! CHECK: %[[RED:.*]] = acc.reduction varPtr(%{{.*}} : !fir.ref<!fir.array<?xf32>>) bounds(%{{.*}}) -> !fir.ref<!fir.array<?xf32>> {name = "a"}
11551168
! CHECK: acc.parallel reduction(@reduction_max_box_Uxf32 -> %[[RED]] : !fir.ref<!fir.array<?xf32>>) {
1169+
! CHECK: acc.copyout accPtr(%[[RED]] : !fir.ref<!fir.array<?xf32>>) bounds(%{{.*}}) to varPtr(%{{.*}} : !fir.ref<!fir.array<?xf32>>) {dataClause = #acc<data_clause acc_reduction>, name = "a"}
11561170

11571171
subroutine acc_reduction_add_dynamic_extent_add_with_section(a)
11581172
integer :: a(:)
@@ -1167,6 +1181,7 @@ subroutine acc_reduction_add_dynamic_extent_add_with_section(a)
11671181
! CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[DECL]]#0 : (!fir.box<!fir.array<?xi32>>) -> !fir.ref<!fir.array<?xi32>>
11681182
! CHECK: %[[RED:.*]] = acc.reduction varPtr(%[[BOX_ADDR]] : !fir.ref<!fir.array<?xi32>>) bounds(%[[BOUND]]) -> !fir.ref<!fir.array<?xi32>> {name = "a(2:4)"}
11691183
! CHECK: acc.parallel reduction(@reduction_add_section_lb1.ub3_box_Uxi32 -> %[[RED]] : !fir.ref<!fir.array<?xi32>>)
1184+
! CHECK: acc.copyout accPtr(%[[RED]] : !fir.ref<!fir.array<?xi32>>) bounds(%[[BOUND]]) to varPtr(%[[BOX_ADDR]] : !fir.ref<!fir.array<?xi32>>) {dataClause = #acc<data_clause acc_reduction>, name = "a(2:4)"}
11701185

11711186
subroutine acc_reduction_add_allocatable(a)
11721187
real, allocatable :: a(:)
@@ -1180,8 +1195,9 @@ subroutine acc_reduction_add_allocatable(a)
11801195
! CHECK: %[[BOX:.*]] = fir.load %[[DECL]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
11811196
! CHECK: %[[BOUND:.*]] = acc.bounds lowerbound(%c0{{.*}} : index) upperbound(%{{.*}} : index) extent(%{{.*}}#1 : index) stride(%{{.*}}#2 : index) startIdx(%{{.*}}#0 : index) {strideInBytes = true}
11821197
! CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[BOX]] : (!fir.box<!fir.heap<!fir.array<?xf32>>>) -> !fir.heap<!fir.array<?xf32>>
1183-
! CHECK: %[[RED:.*]] = acc.reduction varPtr(%[[BOX_ADDR]] : !fir.heap<!fir.array<?xf32>>) bounds(%{{[0-9]+}}) -> !fir.heap<!fir.array<?xf32>> {name = "a"}
1198+
! CHECK: %[[RED:.*]] = acc.reduction varPtr(%[[BOX_ADDR]] : !fir.heap<!fir.array<?xf32>>) bounds(%[[BOUND]]) -> !fir.heap<!fir.array<?xf32>> {name = "a"}
11841199
! CHECK: acc.parallel reduction(@reduction_max_box_heap_Uxf32 -> %[[RED]] : !fir.heap<!fir.array<?xf32>>)
1200+
! CHECK: acc.copyout accPtr(%[[RED]] : !fir.heap<!fir.array<?xf32>>) bounds(%[[BOUND]]) to varPtr(%[[BOX_ADDR]] : !fir.heap<!fir.array<?xf32>>) {dataClause = #acc<data_clause acc_reduction>, name = "a"}
11851201

11861202
subroutine acc_reduction_add_pointer_array(a)
11871203
real, pointer :: a(:)
@@ -1197,6 +1213,7 @@ subroutine acc_reduction_add_pointer_array(a)
11971213
! CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[BOX]] : (!fir.box<!fir.ptr<!fir.array<?xf32>>>) -> !fir.ptr<!fir.array<?xf32>>
11981214
! CHECK: %[[RED:.*]] = acc.reduction varPtr(%[[BOX_ADDR]] : !fir.ptr<!fir.array<?xf32>>) bounds(%[[BOUND]]) -> !fir.ptr<!fir.array<?xf32>> {name = "a"}
11991215
! CHECK: acc.parallel reduction(@reduction_max_box_ptr_Uxf32 -> %[[RED]] : !fir.ptr<!fir.array<?xf32>>)
1216+
! CHECK: acc.copyout accPtr(%[[RED]] : !fir.ptr<!fir.array<?xf32>>) bounds(%[[BOUND]]) to varPtr(%[[BOX_ADDR]] : !fir.ptr<!fir.array<?xf32>>) {dataClause = #acc<data_clause acc_reduction>, name = "a"}
12001217

12011218
subroutine acc_reduction_max_dynamic_extent_max(a, n)
12021219
integer :: n
@@ -1211,3 +1228,4 @@ subroutine acc_reduction_max_dynamic_extent_max(a, n)
12111228
! CHECK: %[[ADDR:.*]] = fir.box_addr %[[DECL_A]]#0 : (!fir.box<!fir.array<?x?xf32>>) -> !fir.ref<!fir.array<?x?xf32>>
12121229
! CHECK: %[[RED:.*]] = acc.reduction varPtr(%[[ADDR]] : !fir.ref<!fir.array<?x?xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<?x?xf32>> {name = "a"}
12131230
! CHECK: acc.parallel reduction(@reduction_max_box_UxUxf32 -> %[[RED]] : !fir.ref<!fir.array<?x?xf32>>)
1231+
! CHECK: acc.copyout accPtr(%[[RED]] : !fir.ref<!fir.array<?x?xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%{{.*}} : !fir.ref<!fir.array<?x?xf32>>) {dataClause = #acc<data_clause acc_reduction>, name = "a"}

0 commit comments

Comments
 (0)