-
Couldn't load subscription status.
- Fork 15k
[flang][OpenACC] remap common block member symbols #163752
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-openacc Author: None (jeanPerier) Changes — Full diff: https://github.com/llvm/llvm-project/pull/163752.diff — 2 Files Affected:
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index cfb18914e8126..5bd795d3c719c 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -2250,6 +2250,23 @@ static void processDoLoopBounds(
}
}
+/// Rebind \p member, a symbol whose storage lives in the common block
+/// \p commonBlockSymbol (a declared member or an object equivalenced with
+/// one), so that lowering inside the OpenACC region addresses it relative to
+/// \p newCommonBlockBaseAddress instead of its host storage.
+///
+/// The size() of \p commonBlockSymbol is forwarded to genCommonBlockMember.
+/// Symbols already recorded in \p seenSymbols are left untouched. \p member
+/// is recorded there once it has been rebound, so each symbol is remapped at
+/// most once for the region being processed.
+static void remapCommonBlockMember(
+ Fortran::lower::AbstractConverter &converter, mlir::Location loc,
+ const Fortran::semantics::Symbol &member,
+ mlir::Value newCommonBlockBaseAddress,
+ const Fortran::semantics::Symbol &commonBlockSymbol,
+ llvm::SmallPtrSetImpl<const Fortran::semantics::Symbol *> &seenSymbols) {
+ // Already remapped earlier for this region: keep the first binding.
+ if (seenSymbols.contains(&member))
+ return;
+ // Address of the member, computed from the new common block base address.
+ mlir::Value accMemberValue = Fortran::lower::genCommonBlockMember(
+ converter, loc, member, newCommonBlockBaseAddress,
+ commonBlockSymbol.size());
+ // Reuse the host entity description of the symbol, replacing only its base
+ // address with the one computed above (via fir::substBase).
+ fir::ExtendedValue hostExv = converter.getSymbolExtendedValue(member);
+ fir::ExtendedValue accExv = fir::substBase(hostExv, accMemberValue);
+ converter.bindSymbol(member, accExv);
+ seenSymbols.insert(&member);
+}
+
/// Remap symbols that appeared in OpenACC data clauses to use the results of
/// the corresponding data operations. This allows isolating symbol accesses
/// inside the OpenACC region from accesses in the host and other regions while
@@ -2275,14 +2292,39 @@ static void remapDataOperandSymbols(
builder.setInsertionPointToStart(®ionOp.getRegion().front());
llvm::SmallPtrSet<const Fortran::semantics::Symbol *, 8> seenSymbols;
mlir::IRMapping mapper;
+ mlir::Location loc = regionOp.getLoc();
for (auto [value, symbol] : dataOperandSymbolPairs) {
-
- // If A symbol appears on several data clause, just map it to the first
+ // If a symbol appears on several data clause, just map it to the first
// result (all data operations results for a symbol are pointing same
// memory, so it does not matter which one is used).
if (seenSymbols.contains(&symbol.get()))
continue;
seenSymbols.insert(&symbol.get());
+ // When a common block appears in a directive, remap its members.
+ // Note: this will instantiate all common block members even if they are not
+ // used inside the region. If hlfir.declare DCE is not made possible, this
+ // could be improved to reduce IR noise.
+ if (const auto *commonBlock = symbol->template detailsIf<
+ Fortran::semantics::CommonBlockDetails>()) {
+ const Fortran::semantics::Scope &commonScope = symbol->owner();
+ if (commonScope.equivalenceSets().empty()) {
+ for (auto member : commonBlock->objects())
+ remapCommonBlockMember(converter, loc, *member, value, *symbol,
+ seenSymbols);
+ } else {
+ // Objects equivalenced with common block members still belong to the
+ // common block storage even if they are not part of the common block
+ // declaration. The easiest and most robust way to find all symbols
+ // belonging to the common block is to loop through the scope symbols
+ // and check if they belong to the common.
+ for (const auto &scopeSymbol : commonScope)
+ if (Fortran::semantics::FindCommonBlockContaining(
+ *scopeSymbol.second) == &symbol.get())
+ remapCommonBlockMember(converter, loc, *scopeSymbol.second, value,
+ *symbol, seenSymbols);
+ }
+ continue;
+ }
std::optional<fir::FortranVariableOpInterface> hostDef =
symbolMap.lookupVariableDefinition(symbol);
assert(hostDef.has_value() && llvm::isa<hlfir::DeclareOp>(*hostDef) &&
@@ -2299,10 +2341,8 @@ static void remapDataOperandSymbols(
"box type mismatch between compute region variable and "
"hlfir.declare input unexpected");
if (Fortran::semantics::IsOptional(symbol))
- TODO(regionOp.getLoc(),
- "remapping OPTIONAL symbol in OpenACC compute region");
- auto rawValue =
- fir::BoxAddrOp::create(builder, regionOp.getLoc(), hostType, value);
+ TODO(loc, "remapping OPTIONAL symbol in OpenACC compute region");
+ auto rawValue = fir::BoxAddrOp::create(builder, loc, hostType, value);
mapper.map(hostInput, rawValue);
} else {
assert(!llvm::isa<fir::BaseBoxType>(hostType) &&
@@ -2314,8 +2354,7 @@ static void remapDataOperandSymbols(
assert(fir::isa_ref_type(hostType) && fir::isa_ref_type(computeType) &&
"compute region variable and host variable should both be raw "
"addresses");
- mlir::Value cast =
- builder.createConvert(regionOp.getLoc(), hostType, value);
+ mlir::Value cast = builder.createConvert(loc, hostType, value);
mapper.map(hostInput, cast);
}
if (mlir::Value dummyScope = hostDeclare.getDummyScope()) {
diff --git a/flang/test/Lower/OpenACC/acc-data-operands-remapping-common.f90 b/flang/test/Lower/OpenACC/acc-data-operands-remapping-common.f90
new file mode 100644
index 0000000000000..1ab883e0599c0
--- /dev/null
+++ b/flang/test/Lower/OpenACC/acc-data-operands-remapping-common.f90
@@ -0,0 +1,43 @@
+! Test remapping of common blocks appearing in OpenACC data directives.
+
+! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s
+
+! overlap1 and overlap2 are not named in /comm/, but they share its storage
+! through EQUIVALENCE with x: overlap1 starts at x(50) (byte offset 196) and
+! overlap2 at x(40) (byte offset 156); y follows x at byte offset 400.
+! The FileCheck patterns below verify that all four symbols are rebound
+! inside the acc.parallel region relative to the acc.copyin result of /comm/.
+subroutine test
+ real :: x(100), y(100), overlap1(100), overlap2(100)
+ equivalence (x(50), overlap1)
+ equivalence (x(40), overlap2)
+ common /comm/ x, y
+ !$acc declare link(/comm/)
+ !$acc parallel loop copyin(/comm/)
+ do i = 1, 100
+ x(i) = overlap1(i)*2+ overlap2(i)
+ enddo
+end subroutine
+! CHECK-LABEL: func.func @_QPtest() {
+! CHECK: %[[ADDRESS_OF_0:.*]] = fir.address_of(@comm_)
+! CHECK: %[[COPYIN_0:.*]] = acc.copyin varPtr(%[[ADDRESS_OF_0]] : !fir.ref<!fir.array<800xi8>>) -> !fir.ref<!fir.array<800xi8>> {name = "comm"}
+! CHECK: acc.parallel combined(loop) dataOperands(%[[COPYIN_0]] : !fir.ref<!fir.array<800xi8>>) {
+! CHECK: %[[CONSTANT_8:.*]] = arith.constant 196 : index
+! CHECK: %[[COORDINATE_OF_4:.*]] = fir.coordinate_of %[[COPYIN_0]], %{{.*}} : (!fir.ref<!fir.array<800xi8>>, index) -> !fir.ref<i8>
+! CHECK: %[[CONVERT_4:.*]] = fir.convert %[[COORDINATE_OF_4]] : (!fir.ref<i8>) -> !fir.ptr<!fir.array<100xf32>>
+! CHECK: %[[SHAPE_4:.*]] = fir.shape %{{.*}} : (index) -> !fir.shape<1>
+! CHECK: %[[DECLARE_5:.*]]:2 = hlfir.declare %[[CONVERT_4]](%[[SHAPE_4]]) storage(%[[COPYIN_0]][196]) {uniq_name = "_QFtestEoverlap1"} : (!fir.ptr<!fir.array<100xf32>>, !fir.shape<1>, !fir.ref<!fir.array<800xi8>>) -> (!fir.ptr<!fir.array<100xf32>>, !fir.ptr<!fir.array<100xf32>>)
+! CHECK: %[[CONSTANT_9:.*]] = arith.constant 156 : index
+! CHECK: %[[COORDINATE_OF_5:.*]] = fir.coordinate_of %[[COPYIN_0]], %{{.*}} : (!fir.ref<!fir.array<800xi8>>, index) -> !fir.ref<i8>
+! CHECK: %[[CONVERT_5:.*]] = fir.convert %[[COORDINATE_OF_5]] : (!fir.ref<i8>) -> !fir.ptr<!fir.array<100xf32>>
+! CHECK: %[[SHAPE_5:.*]] = fir.shape %{{.*}} : (index) -> !fir.shape<1>
+! CHECK: %[[DECLARE_6:.*]]:2 = hlfir.declare %[[CONVERT_5]](%[[SHAPE_5]]) storage(%[[COPYIN_0]][156]) {uniq_name = "_QFtestEoverlap2"} : (!fir.ptr<!fir.array<100xf32>>, !fir.shape<1>, !fir.ref<!fir.array<800xi8>>) -> (!fir.ptr<!fir.array<100xf32>>, !fir.ptr<!fir.array<100xf32>>)
+! CHECK: %[[CONSTANT_10:.*]] = arith.constant 0 : index
+! CHECK: %[[COORDINATE_OF_6:.*]] = fir.coordinate_of %[[COPYIN_0]], %{{.*}} : (!fir.ref<!fir.array<800xi8>>, index) -> !fir.ref<i8>
+! CHECK: %[[CONVERT_6:.*]] = fir.convert %[[COORDINATE_OF_6]] : (!fir.ref<i8>) -> !fir.ptr<!fir.array<100xf32>>
+! CHECK: %[[SHAPE_6:.*]] = fir.shape %{{.*}} : (index) -> !fir.shape<1>
+! CHECK: %[[DECLARE_7:.*]]:2 = hlfir.declare %[[CONVERT_6]](%[[SHAPE_6]]) storage(%[[COPYIN_0]][0]) {uniq_name = "_QFtestEx"} : (!fir.ptr<!fir.array<100xf32>>, !fir.shape<1>, !fir.ref<!fir.array<800xi8>>) -> (!fir.ptr<!fir.array<100xf32>>, !fir.ptr<!fir.array<100xf32>>)
+! CHECK: %[[CONSTANT_11:.*]] = arith.constant 400 : index
+! CHECK: %[[COORDINATE_OF_7:.*]] = fir.coordinate_of %[[COPYIN_0]], %{{.*}} : (!fir.ref<!fir.array<800xi8>>, index) -> !fir.ref<i8>
+! CHECK: %[[CONVERT_7:.*]] = fir.convert %[[COORDINATE_OF_7]] : (!fir.ref<i8>) -> !fir.ref<!fir.array<100xf32>>
+! CHECK: %[[SHAPE_7:.*]] = fir.shape %{{.*}} : (index) -> !fir.shape<1>
+! CHECK: %[[DECLARE_8:.*]]:2 = hlfir.declare %[[CONVERT_7]](%[[SHAPE_7]]) storage(%[[COPYIN_0]][400]) {uniq_name = "_QFtestEy"} : (!fir.ref<!fir.array<100xf32>>, !fir.shape<1>, !fir.ref<!fir.array<800xi8>>) -> (!fir.ref<!fir.array<100xf32>>, !fir.ref<!fir.array<100xf32>>)
+! CHECK: acc.loop combined(parallel)
+! CHECK: %[[DESIGNATE_0:.*]] = hlfir.designate %[[DECLARE_5]]#0
+! CHECK: %[[DESIGNATE_1:.*]] = hlfir.designate %[[DECLARE_6]]#0
+! CHECK: %[[DESIGNATE_2:.*]] = hlfir.designate %[[DECLARE_7]]#0
|
|
@llvm/pr-subscribers-flang-fir-hlfir Author: None (jeanPerier) Changes — Full diff: https://github.com/llvm/llvm-project/pull/163752.diff — 2 Files Affected:
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index cfb18914e8126..5bd795d3c719c 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -2250,6 +2250,23 @@ static void processDoLoopBounds(
}
}
+/// Rebind \p member, a symbol whose storage lives in the common block
+/// \p commonBlockSymbol (a declared member or an object equivalenced with
+/// one), so that lowering inside the OpenACC region addresses it relative to
+/// \p newCommonBlockBaseAddress instead of its host storage.
+///
+/// The size() of \p commonBlockSymbol is forwarded to genCommonBlockMember.
+/// Symbols already recorded in \p seenSymbols are left untouched. \p member
+/// is recorded there once it has been rebound, so each symbol is remapped at
+/// most once for the region being processed.
+static void remapCommonBlockMember(
+ Fortran::lower::AbstractConverter &converter, mlir::Location loc,
+ const Fortran::semantics::Symbol &member,
+ mlir::Value newCommonBlockBaseAddress,
+ const Fortran::semantics::Symbol &commonBlockSymbol,
+ llvm::SmallPtrSetImpl<const Fortran::semantics::Symbol *> &seenSymbols) {
+ // Already remapped earlier for this region: keep the first binding.
+ if (seenSymbols.contains(&member))
+ return;
+ // Address of the member, computed from the new common block base address.
+ mlir::Value accMemberValue = Fortran::lower::genCommonBlockMember(
+ converter, loc, member, newCommonBlockBaseAddress,
+ commonBlockSymbol.size());
+ // Reuse the host entity description of the symbol, replacing only its base
+ // address with the one computed above (via fir::substBase).
+ fir::ExtendedValue hostExv = converter.getSymbolExtendedValue(member);
+ fir::ExtendedValue accExv = fir::substBase(hostExv, accMemberValue);
+ converter.bindSymbol(member, accExv);
+ seenSymbols.insert(&member);
+}
+
/// Remap symbols that appeared in OpenACC data clauses to use the results of
/// the corresponding data operations. This allows isolating symbol accesses
/// inside the OpenACC region from accesses in the host and other regions while
@@ -2275,14 +2292,39 @@ static void remapDataOperandSymbols(
builder.setInsertionPointToStart(®ionOp.getRegion().front());
llvm::SmallPtrSet<const Fortran::semantics::Symbol *, 8> seenSymbols;
mlir::IRMapping mapper;
+ mlir::Location loc = regionOp.getLoc();
for (auto [value, symbol] : dataOperandSymbolPairs) {
-
- // If A symbol appears on several data clause, just map it to the first
+ // If a symbol appears on several data clause, just map it to the first
// result (all data operations results for a symbol are pointing same
// memory, so it does not matter which one is used).
if (seenSymbols.contains(&symbol.get()))
continue;
seenSymbols.insert(&symbol.get());
+ // When a common block appears in a directive, remap its members.
+ // Note: this will instantiate all common block members even if they are not
+ // used inside the region. If hlfir.declare DCE is not made possible, this
+ // could be improved to reduce IR noise.
+ if (const auto *commonBlock = symbol->template detailsIf<
+ Fortran::semantics::CommonBlockDetails>()) {
+ const Fortran::semantics::Scope &commonScope = symbol->owner();
+ if (commonScope.equivalenceSets().empty()) {
+ for (auto member : commonBlock->objects())
+ remapCommonBlockMember(converter, loc, *member, value, *symbol,
+ seenSymbols);
+ } else {
+ // Objects equivalenced with common block members still belong to the
+ // common block storage even if they are not part of the common block
+ // declaration. The easiest and most robust way to find all symbols
+ // belonging to the common block is to loop through the scope symbols
+ // and check if they belong to the common.
+ for (const auto &scopeSymbol : commonScope)
+ if (Fortran::semantics::FindCommonBlockContaining(
+ *scopeSymbol.second) == &symbol.get())
+ remapCommonBlockMember(converter, loc, *scopeSymbol.second, value,
+ *symbol, seenSymbols);
+ }
+ continue;
+ }
std::optional<fir::FortranVariableOpInterface> hostDef =
symbolMap.lookupVariableDefinition(symbol);
assert(hostDef.has_value() && llvm::isa<hlfir::DeclareOp>(*hostDef) &&
@@ -2299,10 +2341,8 @@ static void remapDataOperandSymbols(
"box type mismatch between compute region variable and "
"hlfir.declare input unexpected");
if (Fortran::semantics::IsOptional(symbol))
- TODO(regionOp.getLoc(),
- "remapping OPTIONAL symbol in OpenACC compute region");
- auto rawValue =
- fir::BoxAddrOp::create(builder, regionOp.getLoc(), hostType, value);
+ TODO(loc, "remapping OPTIONAL symbol in OpenACC compute region");
+ auto rawValue = fir::BoxAddrOp::create(builder, loc, hostType, value);
mapper.map(hostInput, rawValue);
} else {
assert(!llvm::isa<fir::BaseBoxType>(hostType) &&
@@ -2314,8 +2354,7 @@ static void remapDataOperandSymbols(
assert(fir::isa_ref_type(hostType) && fir::isa_ref_type(computeType) &&
"compute region variable and host variable should both be raw "
"addresses");
- mlir::Value cast =
- builder.createConvert(regionOp.getLoc(), hostType, value);
+ mlir::Value cast = builder.createConvert(loc, hostType, value);
mapper.map(hostInput, cast);
}
if (mlir::Value dummyScope = hostDeclare.getDummyScope()) {
diff --git a/flang/test/Lower/OpenACC/acc-data-operands-remapping-common.f90 b/flang/test/Lower/OpenACC/acc-data-operands-remapping-common.f90
new file mode 100644
index 0000000000000..1ab883e0599c0
--- /dev/null
+++ b/flang/test/Lower/OpenACC/acc-data-operands-remapping-common.f90
@@ -0,0 +1,43 @@
+! Test remapping of common blocks appearing in OpenACC data directives.
+
+! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s
+
+! overlap1 and overlap2 are not named in /comm/, but they share its storage
+! through EQUIVALENCE with x: overlap1 starts at x(50) (byte offset 196) and
+! overlap2 at x(40) (byte offset 156); y follows x at byte offset 400.
+! The FileCheck patterns below verify that all four symbols are rebound
+! inside the acc.parallel region relative to the acc.copyin result of /comm/.
+subroutine test
+ real :: x(100), y(100), overlap1(100), overlap2(100)
+ equivalence (x(50), overlap1)
+ equivalence (x(40), overlap2)
+ common /comm/ x, y
+ !$acc declare link(/comm/)
+ !$acc parallel loop copyin(/comm/)
+ do i = 1, 100
+ x(i) = overlap1(i)*2+ overlap2(i)
+ enddo
+end subroutine
+! CHECK-LABEL: func.func @_QPtest() {
+! CHECK: %[[ADDRESS_OF_0:.*]] = fir.address_of(@comm_)
+! CHECK: %[[COPYIN_0:.*]] = acc.copyin varPtr(%[[ADDRESS_OF_0]] : !fir.ref<!fir.array<800xi8>>) -> !fir.ref<!fir.array<800xi8>> {name = "comm"}
+! CHECK: acc.parallel combined(loop) dataOperands(%[[COPYIN_0]] : !fir.ref<!fir.array<800xi8>>) {
+! CHECK: %[[CONSTANT_8:.*]] = arith.constant 196 : index
+! CHECK: %[[COORDINATE_OF_4:.*]] = fir.coordinate_of %[[COPYIN_0]], %{{.*}} : (!fir.ref<!fir.array<800xi8>>, index) -> !fir.ref<i8>
+! CHECK: %[[CONVERT_4:.*]] = fir.convert %[[COORDINATE_OF_4]] : (!fir.ref<i8>) -> !fir.ptr<!fir.array<100xf32>>
+! CHECK: %[[SHAPE_4:.*]] = fir.shape %{{.*}} : (index) -> !fir.shape<1>
+! CHECK: %[[DECLARE_5:.*]]:2 = hlfir.declare %[[CONVERT_4]](%[[SHAPE_4]]) storage(%[[COPYIN_0]][196]) {uniq_name = "_QFtestEoverlap1"} : (!fir.ptr<!fir.array<100xf32>>, !fir.shape<1>, !fir.ref<!fir.array<800xi8>>) -> (!fir.ptr<!fir.array<100xf32>>, !fir.ptr<!fir.array<100xf32>>)
+! CHECK: %[[CONSTANT_9:.*]] = arith.constant 156 : index
+! CHECK: %[[COORDINATE_OF_5:.*]] = fir.coordinate_of %[[COPYIN_0]], %{{.*}} : (!fir.ref<!fir.array<800xi8>>, index) -> !fir.ref<i8>
+! CHECK: %[[CONVERT_5:.*]] = fir.convert %[[COORDINATE_OF_5]] : (!fir.ref<i8>) -> !fir.ptr<!fir.array<100xf32>>
+! CHECK: %[[SHAPE_5:.*]] = fir.shape %{{.*}} : (index) -> !fir.shape<1>
+! CHECK: %[[DECLARE_6:.*]]:2 = hlfir.declare %[[CONVERT_5]](%[[SHAPE_5]]) storage(%[[COPYIN_0]][156]) {uniq_name = "_QFtestEoverlap2"} : (!fir.ptr<!fir.array<100xf32>>, !fir.shape<1>, !fir.ref<!fir.array<800xi8>>) -> (!fir.ptr<!fir.array<100xf32>>, !fir.ptr<!fir.array<100xf32>>)
+! CHECK: %[[CONSTANT_10:.*]] = arith.constant 0 : index
+! CHECK: %[[COORDINATE_OF_6:.*]] = fir.coordinate_of %[[COPYIN_0]], %{{.*}} : (!fir.ref<!fir.array<800xi8>>, index) -> !fir.ref<i8>
+! CHECK: %[[CONVERT_6:.*]] = fir.convert %[[COORDINATE_OF_6]] : (!fir.ref<i8>) -> !fir.ptr<!fir.array<100xf32>>
+! CHECK: %[[SHAPE_6:.*]] = fir.shape %{{.*}} : (index) -> !fir.shape<1>
+! CHECK: %[[DECLARE_7:.*]]:2 = hlfir.declare %[[CONVERT_6]](%[[SHAPE_6]]) storage(%[[COPYIN_0]][0]) {uniq_name = "_QFtestEx"} : (!fir.ptr<!fir.array<100xf32>>, !fir.shape<1>, !fir.ref<!fir.array<800xi8>>) -> (!fir.ptr<!fir.array<100xf32>>, !fir.ptr<!fir.array<100xf32>>)
+! CHECK: %[[CONSTANT_11:.*]] = arith.constant 400 : index
+! CHECK: %[[COORDINATE_OF_7:.*]] = fir.coordinate_of %[[COPYIN_0]], %{{.*}} : (!fir.ref<!fir.array<800xi8>>, index) -> !fir.ref<i8>
+! CHECK: %[[CONVERT_7:.*]] = fir.convert %[[COORDINATE_OF_7]] : (!fir.ref<i8>) -> !fir.ref<!fir.array<100xf32>>
+! CHECK: %[[SHAPE_7:.*]] = fir.shape %{{.*}} : (index) -> !fir.shape<1>
+! CHECK: %[[DECLARE_8:.*]]:2 = hlfir.declare %[[CONVERT_7]](%[[SHAPE_7]]) storage(%[[COPYIN_0]][400]) {uniq_name = "_QFtestEy"} : (!fir.ref<!fir.array<100xf32>>, !fir.shape<1>, !fir.ref<!fir.array<800xi8>>) -> (!fir.ref<!fir.array<100xf32>>, !fir.ref<!fir.array<100xf32>>)
+! CHECK: acc.loop combined(parallel)
+! CHECK: %[[DESIGNATE_0:.*]] = hlfir.designate %[[DECLARE_5]]#0
+! CHECK: %[[DESIGNATE_1:.*]] = hlfir.designate %[[DECLARE_6]]#0
+! CHECK: %[[DESIGNATE_2:.*]] = hlfir.designate %[[DECLARE_7]]#0
|
|
ping |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
No description provided.