-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[Flang][OpenMP] Increase detection capability for requires usm (and others) #162971
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…thers) Currently, the compiler only picks up some cases where requires is designated such as in the main program. However, it'll gloss over cases such as when it is specified by a user in a module, that's then used elsewhere. This patch attempts to amend that by searching the varying scopes in the current program module more comprehensively.
@llvm/pr-subscribers-flang-fir-hlfir Author: None (agozillon) ChangesCurrently, the compiler only picks up some cases where requires is designated such as in the main program. However, it'll gloss over cases such as when it is specified by a user in a module, that's then used elsewhere. This patch attempts to amend that by searching the varying scopes in the current program module more comprehensively. Full diff: https://github.com/llvm/llvm-project/pull/162971.diff 2 Files Affected:
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 68adf346fe8c0..358b57d76d32e 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -415,7 +415,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
// - Define module variables and OpenMP/OpenACC declarative constructs so
// they are available before lowering any function that may use them.
bool hasMainProgram = false;
- const Fortran::semantics::Symbol *globalOmpRequiresSymbol = nullptr;
+ llvm::SmallVector<const Fortran::semantics::Symbol *>
+ globalOmpRequiresSymbols;
createBuilderOutsideOfFuncOpAndDo([&]() {
for (Fortran::lower::pft::Program::Units &u : pft.getUnits()) {
Fortran::common::visit(
@@ -424,8 +425,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
if (f.isMainProgram())
hasMainProgram = true;
declareFunction(f);
- if (!globalOmpRequiresSymbol)
- globalOmpRequiresSymbol = f.getScope().symbol();
+ globalOmpRequiresSymbols.push_back(f.getScope().symbol());
},
[&](Fortran::lower::pft::ModuleLikeUnit &m) {
lowerModuleDeclScope(m);
@@ -433,12 +433,15 @@ class FirConverter : public Fortran::lower::AbstractConverter {
m.containedUnitList)
if (auto *f =
std::get_if<Fortran::lower::pft::FunctionLikeUnit>(
- &unit))
+ &unit)) {
declareFunction(*f);
+ globalOmpRequiresSymbols.push_back(
+ f->getScope().symbol());
+ }
+ globalOmpRequiresSymbols.push_back(m.getScope().symbol());
},
[&](Fortran::lower::pft::BlockDataUnit &b) {
- if (!globalOmpRequiresSymbol)
- globalOmpRequiresSymbol = b.symTab.symbol();
+ globalOmpRequiresSymbols.push_back(b.symTab.symbol());
},
[&](Fortran::lower::pft::CompilerDirectiveUnit &d) {},
[&](Fortran::lower::pft::OpenACCDirectiveUnit &d) {},
@@ -481,7 +484,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
Fortran::common::LanguageFeature::Coarray));
});
- finalizeOpenMPLowering(globalOmpRequiresSymbol);
+ finalizeOpenMPLowering(globalOmpRequiresSymbols);
}
/// Declare a function.
@@ -6681,7 +6684,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
/// Performing OpenMP lowering actions that were deferred to the end of
/// lowering.
void finalizeOpenMPLowering(
- const Fortran::semantics::Symbol *globalOmpRequiresSymbol) {
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
+ &globalOmpRequiresSymbol) {
if (!ompDeferredDeclareTarget.empty()) {
bool deferredDeviceFuncFound =
Fortran::lower::markOpenMPDeferredDeclareTargetFunctions(
@@ -6690,9 +6694,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
}
// Set the module attribute related to OpenMP requires directives
- if (ompDeviceCodeFound)
- Fortran::lower::genOpenMPRequires(getModuleOp().getOperation(),
- globalOmpRequiresSymbol);
+ if (ompDeviceCodeFound) {
+ for (auto sym : globalOmpRequiresSymbol)
+ Fortran::lower::genOpenMPRequires(getModuleOp().getOperation(), sym);
+ }
}
/// Record fir.dummy_scope operation for this function.
diff --git a/flang/test/Lower/OpenMP/requires-usm.f90 b/flang/test/Lower/OpenMP/requires-usm.f90
new file mode 100644
index 0000000000000..600e8387d6ad4
--- /dev/null
+++ b/flang/test/Lower/OpenMP/requires-usm.f90
@@ -0,0 +1,22 @@
+! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
+! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s
+! RUN: bbc -fopenmp -emit-hlfir %s -o - | FileCheck %s
+! RUN: bbc -fopenmp -fopenmp-is-target-device -emit-hlfir %s -o - | FileCheck %s
+
+! Verify that we pick up USM and apply it correctly when it is specified
+! outside of the program.
+
+!CHECK: module attributes {
+!CHECK-SAME: omp.requires = #omp<clause_requires unified_shared_memory>
+module declare_mod
+ implicit none
+!$omp requires unified_shared_memory
+ contains
+end module
+
+program main
+ use declare_mod
+ implicit none
+!$omp target
+!$omp end target
+end program
|
@llvm/pr-subscribers-flang-openmp Author: None (agozillon) ChangesCurrently, the compiler only picks up some cases where requires is designated such as in the main program. However, it'll gloss over cases such as when it is specified by a user in a module, that's then used elsewhere. This patch attempts to amend that by searching the varying scopes in the current program module more comprehensively. Full diff: https://github.com/llvm/llvm-project/pull/162971.diff 2 Files Affected:
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 68adf346fe8c0..358b57d76d32e 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -415,7 +415,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
// - Define module variables and OpenMP/OpenACC declarative constructs so
// they are available before lowering any function that may use them.
bool hasMainProgram = false;
- const Fortran::semantics::Symbol *globalOmpRequiresSymbol = nullptr;
+ llvm::SmallVector<const Fortran::semantics::Symbol *>
+ globalOmpRequiresSymbols;
createBuilderOutsideOfFuncOpAndDo([&]() {
for (Fortran::lower::pft::Program::Units &u : pft.getUnits()) {
Fortran::common::visit(
@@ -424,8 +425,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
if (f.isMainProgram())
hasMainProgram = true;
declareFunction(f);
- if (!globalOmpRequiresSymbol)
- globalOmpRequiresSymbol = f.getScope().symbol();
+ globalOmpRequiresSymbols.push_back(f.getScope().symbol());
},
[&](Fortran::lower::pft::ModuleLikeUnit &m) {
lowerModuleDeclScope(m);
@@ -433,12 +433,15 @@ class FirConverter : public Fortran::lower::AbstractConverter {
m.containedUnitList)
if (auto *f =
std::get_if<Fortran::lower::pft::FunctionLikeUnit>(
- &unit))
+ &unit)) {
declareFunction(*f);
+ globalOmpRequiresSymbols.push_back(
+ f->getScope().symbol());
+ }
+ globalOmpRequiresSymbols.push_back(m.getScope().symbol());
},
[&](Fortran::lower::pft::BlockDataUnit &b) {
- if (!globalOmpRequiresSymbol)
- globalOmpRequiresSymbol = b.symTab.symbol();
+ globalOmpRequiresSymbols.push_back(b.symTab.symbol());
},
[&](Fortran::lower::pft::CompilerDirectiveUnit &d) {},
[&](Fortran::lower::pft::OpenACCDirectiveUnit &d) {},
@@ -481,7 +484,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
Fortran::common::LanguageFeature::Coarray));
});
- finalizeOpenMPLowering(globalOmpRequiresSymbol);
+ finalizeOpenMPLowering(globalOmpRequiresSymbols);
}
/// Declare a function.
@@ -6681,7 +6684,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
/// Performing OpenMP lowering actions that were deferred to the end of
/// lowering.
void finalizeOpenMPLowering(
- const Fortran::semantics::Symbol *globalOmpRequiresSymbol) {
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
+ &globalOmpRequiresSymbol) {
if (!ompDeferredDeclareTarget.empty()) {
bool deferredDeviceFuncFound =
Fortran::lower::markOpenMPDeferredDeclareTargetFunctions(
@@ -6690,9 +6694,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
}
// Set the module attribute related to OpenMP requires directives
- if (ompDeviceCodeFound)
- Fortran::lower::genOpenMPRequires(getModuleOp().getOperation(),
- globalOmpRequiresSymbol);
+ if (ompDeviceCodeFound) {
+ for (auto sym : globalOmpRequiresSymbol)
+ Fortran::lower::genOpenMPRequires(getModuleOp().getOperation(), sym);
+ }
}
/// Record fir.dummy_scope operation for this function.
diff --git a/flang/test/Lower/OpenMP/requires-usm.f90 b/flang/test/Lower/OpenMP/requires-usm.f90
new file mode 100644
index 0000000000000..600e8387d6ad4
--- /dev/null
+++ b/flang/test/Lower/OpenMP/requires-usm.f90
@@ -0,0 +1,22 @@
+! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
+! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s
+! RUN: bbc -fopenmp -emit-hlfir %s -o - | FileCheck %s
+! RUN: bbc -fopenmp -fopenmp-is-target-device -emit-hlfir %s -o - | FileCheck %s
+
+! Verify that we pick up USM and apply it correctly when it is specified
+! outside of the program.
+
+!CHECK: module attributes {
+!CHECK-SAME: omp.requires = #omp<clause_requires unified_shared_memory>
+module declare_mod
+ implicit none
+!$omp requires unified_shared_memory
+ contains
+end module
+
+program main
+ use declare_mod
+ implicit none
+!$omp target
+!$omp end target
+end program
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM with minor comment.
globalOmpRequiresSymbol); | ||
if (ompDeviceCodeFound) { | ||
for (auto sym : globalOmpRequiresSymbol) | ||
Fortran::lower::genOpenMPRequires(getModuleOp().getOperation(), sym); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we optimise this by collecting the ModuleOps
in a DenseSet
and then fire the genOpenMPRequires
calls to prevent duplicate calls for the same Module?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I might be misunderstanding what you mean or how bridge works, but won't we only ever invoke this on a single module (the one bridge is currently lowering) at a time and it'll invoke this at the end of that modules lowering on the requires symbols held within the currently being lowered file as part of the finalization process?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, I'm not too familiar with how bridge works as well. I am just going by the loop for (auto sym : globalOmpRequiresSymbol)
. If sym1
and sym2
are from the same module, do we need two calls to genOpenMPRequires
or can we just have one?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Currently we'd need to invoke it per symbol we gather, as the symbols are from various different scopes within the lowered file that could but don't necessarily contain a requires directive
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, thanks for clearing that.
Thank you Andrew for picking up this work. Were you aware of #77082? Module support was a known missing feature that I intended to address by collecting that information from |
This should be handled in the semantic analysis: the requirement flags from the used module(s) should be added to the user. This would also allow us to detect conflicting requirements. |
Correct. When we read a module file, we call ResolveNames, which should put the requirement flags on the module symbol. The scope of that symbol is then returned from the Read function. When we handle UseStmt, we should merge the flags from the module symbol into the current program unit symbol. Then we won't need any of the changes in this PR. |
I'm actually working on handling the optional boolean argument to the requirement clauses, which is somewhat orthogonal to this work. Let me know if you're planning to continue working on this in the near future. If so, I can work on something else in the meantime and wait for you to finish your part. |
I don't really mind, I'd just really love for this feature to be a thing in the near future as it's not ideal for users to specify requires usm or whatever feature and for the compiler to ignore it and generate incorrect IR when we have the capability to generate the correct IR :-) |
Going to close this PR at the moment as I think Krzysztof is planning to pick it up in his work! And go down the more correct and robust route that Sergio originally began. |
Currently, the compiler only picks up some cases where requires is designated such as in the main program. However, it'll gloss over cases such as when it is specified by a user in a module, that's then used elsewhere.
This patch attempts to amend that by searching the varying scopes in the current program module more comprehensively.