Skip to content

Commit a13f59d

Browse files
jeanPerieraokblast
authored and committed
[flang][OpenACC] remap common block member symbols (llvm#163752)
1 parent 1cbf220 commit a13f59d

File tree

2 files changed

+90
-8
lines changed

2 files changed

+90
-8
lines changed

flang/lib/Lower/OpenACC.cpp

Lines changed: 47 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2366,6 +2366,23 @@ static void processDoLoopBounds(
23662366
}
23672367
}
23682368

2369+
/// Rebind a single common block member so that it is addressed through
/// \p newCommonBlockBaseAddress instead of its current base.
///
/// The member address is regenerated with
/// Fortran::lower::genCommonBlockMember from \p newCommonBlockBaseAddress and
/// the size of \p commonBlockSymbol. The extended value currently bound to
/// \p member in \p converter then has that address substituted as its base and
/// is bound back to \p member.
///
/// Members already present in \p seenSymbols are left untouched; every member
/// that gets rebound is added to \p seenSymbols.
static void remapCommonBlockMember(
    Fortran::lower::AbstractConverter &converter, mlir::Location loc,
    const Fortran::semantics::Symbol &member,
    mlir::Value newCommonBlockBaseAddress,
    const Fortran::semantics::Symbol &commonBlockSymbol,
    llvm::SmallPtrSetImpl<const Fortran::semantics::Symbol *> &seenSymbols) {
  const Fortran::semantics::Symbol *memberKey = &member;
  if (!seenSymbols.contains(memberKey)) {
    // Compute the member address relative to the new base before querying the
    // binding that is about to be replaced.
    mlir::Value remappedMemberAddr = Fortran::lower::genCommonBlockMember(
        converter, loc, member, newCommonBlockBaseAddress,
        commonBlockSymbol.size());
    fir::ExtendedValue previousBinding =
        converter.getSymbolExtendedValue(member);
    converter.bindSymbol(member,
                         fir::substBase(previousBinding, remappedMemberAddr));
    seenSymbols.insert(memberKey);
  }
}
2385+
23692386
/// Remap symbols that appeared in OpenACC data clauses to use the results of
23702387
/// the corresponding data operations. This allows isolating symbol accesses
23712388
/// inside the OpenACC region from accesses in the host and other regions while
@@ -2391,14 +2408,39 @@ static void remapDataOperandSymbols(
23912408
builder.setInsertionPointToStart(&regionOp.getRegion().front());
23922409
llvm::SmallPtrSet<const Fortran::semantics::Symbol *, 8> seenSymbols;
23932410
mlir::IRMapping mapper;
2411+
mlir::Location loc = regionOp.getLoc();
23942412
for (auto [value, symbol] : dataOperandSymbolPairs) {
2395-
2396-
// If A symbol appears on several data clause, just map it to the first
2413+
// If a symbol appears on several data clauses, just map it to the first
23972414
// result (all data operation results for a symbol point to the same
23982415
// memory, so it does not matter which one is used).
23992416
if (seenSymbols.contains(&symbol.get()))
24002417
continue;
24012418
seenSymbols.insert(&symbol.get());
2419+
// When a common block appears in a directive, remap its members.
2420+
// Note: this will instantiate all common block members even if they are not
2421+
// used inside the region. If hlfir.declare DCE is not made possible, this
2422+
// could be improved to reduce IR noise.
2423+
if (const auto *commonBlock = symbol->template detailsIf<
2424+
Fortran::semantics::CommonBlockDetails>()) {
2425+
const Fortran::semantics::Scope &commonScope = symbol->owner();
2426+
if (commonScope.equivalenceSets().empty()) {
2427+
for (auto member : commonBlock->objects())
2428+
remapCommonBlockMember(converter, loc, *member, value, *symbol,
2429+
seenSymbols);
2430+
} else {
2431+
// Objects equivalenced with common block members still belong to the
2432+
// common block storage even if they are not part of the common block
2433+
// declaration. The easiest and most robust way to find all symbols
2434+
// belonging to the common block is to loop through the scope symbols
2435+
// and check if they belong to the common.
2436+
for (const auto &scopeSymbol : commonScope)
2437+
if (Fortran::semantics::FindCommonBlockContaining(
2438+
*scopeSymbol.second) == &symbol.get())
2439+
remapCommonBlockMember(converter, loc, *scopeSymbol.second, value,
2440+
*symbol, seenSymbols);
2441+
}
2442+
continue;
2443+
}
24022444
std::optional<fir::FortranVariableOpInterface> hostDef =
24032445
symbolMap.lookupVariableDefinition(symbol);
24042446
assert(hostDef.has_value() && llvm::isa<hlfir::DeclareOp>(*hostDef) &&
@@ -2415,10 +2457,8 @@ static void remapDataOperandSymbols(
24152457
"box type mismatch between compute region variable and "
24162458
"hlfir.declare input unexpected");
24172459
if (Fortran::semantics::IsOptional(symbol))
2418-
TODO(regionOp.getLoc(),
2419-
"remapping OPTIONAL symbol in OpenACC compute region");
2420-
auto rawValue =
2421-
fir::BoxAddrOp::create(builder, regionOp.getLoc(), hostType, value);
2460+
TODO(loc, "remapping OPTIONAL symbol in OpenACC compute region");
2461+
auto rawValue = fir::BoxAddrOp::create(builder, loc, hostType, value);
24222462
mapper.map(hostInput, rawValue);
24232463
} else {
24242464
assert(!llvm::isa<fir::BaseBoxType>(hostType) &&
@@ -2430,8 +2470,7 @@ static void remapDataOperandSymbols(
24302470
assert(fir::isa_ref_type(hostType) && fir::isa_ref_type(computeType) &&
24312471
"compute region variable and host variable should both be raw "
24322472
"addresses");
2433-
mlir::Value cast =
2434-
builder.createConvert(regionOp.getLoc(), hostType, value);
2473+
mlir::Value cast = builder.createConvert(loc, hostType, value);
24352474
mapper.map(hostInput, cast);
24362475
}
24372476
if (mlir::Value dummyScope = hostDeclare.getDummyScope()) {
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
! Test remapping of common blocks appearing in OpenACC data directives.
2+
3+
! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s
4+
5+
! Fixture: x and y are the declared members of common block /comm/, while
! overlap1 and overlap2 are only tied to it through EQUIVALENCE with x. The
! expected IR below wants a remapped hlfir.declare for all four symbols
! inside the acc.parallel region.
subroutine test
6+
real :: x(100), y(100), overlap1(100), overlap2(100)
7+
! overlap1 starts at x(50); the expected storage offset for it is 196.
equivalence (x(50), overlap1)
8+
! overlap2 starts at x(40); the expected storage offset for it is 156.
equivalence (x(40), overlap2)
9+
! The expected storage offsets are 0 for x and 400 for y, in an 800-byte block.
common /comm/ x, y
10+
!$acc declare link(/comm/)
11+
! The whole common block (not individual members) appears in the data clause.
!$acc parallel loop copyin(/comm/)
12+
do i = 1, 100
13+
! Uses both equivalenced aliases and x; the expected hlfir.designate ops
! reference the declares created inside the region.
x(i) = overlap1(i)*2+ overlap2(i)
14+
enddo
15+
end subroutine
16+
! CHECK-LABEL: func.func @_QPtest() {
17+
! CHECK: %[[ADDRESS_OF_0:.*]] = fir.address_of(@comm_)
18+
! CHECK: %[[COPYIN_0:.*]] = acc.copyin varPtr(%[[ADDRESS_OF_0]] : !fir.ref<!fir.array<800xi8>>) -> !fir.ref<!fir.array<800xi8>> {name = "comm"}
19+
! CHECK: acc.parallel combined(loop) dataOperands(%[[COPYIN_0]] : !fir.ref<!fir.array<800xi8>>) {
20+
! CHECK: %[[CONSTANT_8:.*]] = arith.constant 196 : index
21+
! CHECK: %[[COORDINATE_OF_4:.*]] = fir.coordinate_of %[[COPYIN_0]], %{{.*}} : (!fir.ref<!fir.array<800xi8>>, index) -> !fir.ref<i8>
22+
! CHECK: %[[CONVERT_4:.*]] = fir.convert %[[COORDINATE_OF_4]] : (!fir.ref<i8>) -> !fir.ptr<!fir.array<100xf32>>
23+
! CHECK: %[[SHAPE_4:.*]] = fir.shape %{{.*}} : (index) -> !fir.shape<1>
24+
! CHECK: %[[DECLARE_5:.*]]:2 = hlfir.declare %[[CONVERT_4]](%[[SHAPE_4]]) storage(%[[COPYIN_0]][196]) {uniq_name = "_QFtestEoverlap1"} : (!fir.ptr<!fir.array<100xf32>>, !fir.shape<1>, !fir.ref<!fir.array<800xi8>>) -> (!fir.ptr<!fir.array<100xf32>>, !fir.ptr<!fir.array<100xf32>>)
25+
! CHECK: %[[CONSTANT_9:.*]] = arith.constant 156 : index
26+
! CHECK: %[[COORDINATE_OF_5:.*]] = fir.coordinate_of %[[COPYIN_0]], %{{.*}} : (!fir.ref<!fir.array<800xi8>>, index) -> !fir.ref<i8>
27+
! CHECK: %[[CONVERT_5:.*]] = fir.convert %[[COORDINATE_OF_5]] : (!fir.ref<i8>) -> !fir.ptr<!fir.array<100xf32>>
28+
! CHECK: %[[SHAPE_5:.*]] = fir.shape %{{.*}} : (index) -> !fir.shape<1>
29+
! CHECK: %[[DECLARE_6:.*]]:2 = hlfir.declare %[[CONVERT_5]](%[[SHAPE_5]]) storage(%[[COPYIN_0]][156]) {uniq_name = "_QFtestEoverlap2"} : (!fir.ptr<!fir.array<100xf32>>, !fir.shape<1>, !fir.ref<!fir.array<800xi8>>) -> (!fir.ptr<!fir.array<100xf32>>, !fir.ptr<!fir.array<100xf32>>)
30+
! CHECK: %[[CONSTANT_10:.*]] = arith.constant 0 : index
31+
! CHECK: %[[COORDINATE_OF_6:.*]] = fir.coordinate_of %[[COPYIN_0]], %{{.*}} : (!fir.ref<!fir.array<800xi8>>, index) -> !fir.ref<i8>
32+
! CHECK: %[[CONVERT_6:.*]] = fir.convert %[[COORDINATE_OF_6]] : (!fir.ref<i8>) -> !fir.ptr<!fir.array<100xf32>>
33+
! CHECK: %[[SHAPE_6:.*]] = fir.shape %{{.*}} : (index) -> !fir.shape<1>
34+
! CHECK: %[[DECLARE_7:.*]]:2 = hlfir.declare %[[CONVERT_6]](%[[SHAPE_6]]) storage(%[[COPYIN_0]][0]) {uniq_name = "_QFtestEx"} : (!fir.ptr<!fir.array<100xf32>>, !fir.shape<1>, !fir.ref<!fir.array<800xi8>>) -> (!fir.ptr<!fir.array<100xf32>>, !fir.ptr<!fir.array<100xf32>>)
35+
! CHECK: %[[CONSTANT_11:.*]] = arith.constant 400 : index
36+
! CHECK: %[[COORDINATE_OF_7:.*]] = fir.coordinate_of %[[COPYIN_0]], %{{.*}} : (!fir.ref<!fir.array<800xi8>>, index) -> !fir.ref<i8>
37+
! CHECK: %[[CONVERT_7:.*]] = fir.convert %[[COORDINATE_OF_7]] : (!fir.ref<i8>) -> !fir.ref<!fir.array<100xf32>>
38+
! CHECK: %[[SHAPE_7:.*]] = fir.shape %{{.*}} : (index) -> !fir.shape<1>
39+
! CHECK: %[[DECLARE_8:.*]]:2 = hlfir.declare %[[CONVERT_7]](%[[SHAPE_7]]) storage(%[[COPYIN_0]][400]) {uniq_name = "_QFtestEy"} : (!fir.ref<!fir.array<100xf32>>, !fir.shape<1>, !fir.ref<!fir.array<800xi8>>) -> (!fir.ref<!fir.array<100xf32>>, !fir.ref<!fir.array<100xf32>>)
40+
! CHECK: acc.loop combined(parallel)
41+
! CHECK: %[[DESIGNATE_0:.*]] = hlfir.designate %[[DECLARE_5]]#0
42+
! CHECK: %[[DESIGNATE_1:.*]] = hlfir.designate %[[DECLARE_6]]#0
43+
! CHECK: %[[DESIGNATE_2:.*]] = hlfir.designate %[[DECLARE_7]]#0

0 commit comments

Comments
 (0)