[flang][OpenMP] Improve reduction of Scalar ArrayElement types #163940
Currently, Flang does not correctly lower an ArrayElement that appears in an OpenMP reduction clause. Rather than lowering the array element, the whole array is lowered, which leads to slower performance in the user's program.

This patch rectifies this by rewriting the parse tree while processing semantics. A use of an ArrayElement in an OpenMP reduction clause is identified and replaced with a temporary, both in the reduction clause and anywhere that array element is used within the respective DoConstruct. Once the DoConstruct has finished, if the ArrayElement has been used within the Do loop, the value of the temporary is assigned back to the array element. One limitation of this approach is that if the ArrayElement is not used, there is no element available in the parse tree with which to build the reassignment, so it is only done if the element is used.

The change is made in the parse tree because of how ArrayElements are lowered. When lowering, the ArrayElement expression used in the reduction is substituted with a reference to its symbol, which in this case is the whole array. Replacing it with a temporary avoids lowering the full array, as the reduction references the temporary instead. Addressing this in lowering would require a major rethink of how a considerable amount of non-OpenMP code is lowered and, as such, was not deemed the appropriate course of action for this specific case.

This process is done after the initial semantics pass so as not to affect the checking of the user's original code. If an array element has been replaced, the first pass of semantics needs to be rerun to ensure all TypedExprs are correctly captured; otherwise lowering will not function correctly. This step is only done if an ArrayElement is replaced.

Testing is covered by reduction17.f90. This checks the parse tree, unparsing and HLFIR to ensure the temporary is used in the reduction clause and Do loop. The assignment to the temporary and the reassignment back to the ArrayElement are also checked to ensure they are inserted at the correct location.

reduction09.f90 has also been reformatted to rely on FileCheck. As the parse tree is changing, the output differs from the user's source, so we can no longer rely on test_symbols.py for this test. The same information is still checked, with test cases covering both using and not using the ArrayElement in the Do loop.

Array sections are not affected by this change, only uses of single array elements.
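The copy-in, reduce, copy-out shape that the description outlines can be modelled outside of Flang. The following is a minimal C++ analogue, not code from the patch: `reduceElementViaTemp` and its parameters are hypothetical names chosen for illustration, and the loop body is an arbitrary sum.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: reduce into x[idx] by way of a scalar temporary,
// mirroring the three steps the rewrite places around the loop.
int reduceElementViaTemp(std::vector<int> &x, std::size_t idx, int n) {
  assert(idx < x.size() && "element must exist");
  int temp = x[idx]; // copy-in, placed before the OpenMP construct
  // Only the scalar is named in the reduction clause, not the whole array.
  // Without OpenMP enabled the pragma is ignored and the loop runs serially.
#pragma omp parallel for reduction(+ : temp)
  for (int i = 1; i <= n; ++i) {
    temp += i; // each use of x[idx] in the body becomes a use of temp
  }
  x[idx] = temp; // copy-out, placed after the construct
  return temp;
}
```

In this sketch the copy-out is unconditional; the patch as described only emits it when the element is referenced inside the loop.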
@llvm/pr-subscribers-flang-openmp @llvm/pr-subscribers-flang-semantics
Author: Jack Styles (Stylie777)
Patch is 46.64 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/163940.diff
5 Files Affected:
diff --git a/flang/lib/Semantics/rewrite-parse-tree.cpp b/flang/lib/Semantics/rewrite-parse-tree.cpp
index 5b7dab309eda7..5379dcdd3d40c 100644
--- a/flang/lib/Semantics/rewrite-parse-tree.cpp
+++ b/flang/lib/Semantics/rewrite-parse-tree.cpp
@@ -95,6 +95,86 @@ class RewriteMutator {
parser::Messages &messages_;
};
+class ReplacementTemp {
+public:
+ ReplacementTemp() {}
+
+ void createTempSymbol(
+ SourceName &source, Scope &scope, SemanticsContext &context);
+ void setOriginalSubscriptInt(
+ std::list<parser::SectionSubscript> &sectionSubscript);
+ Symbol *getTempSymbol() { return replacementTempSymbol_; }
+ Symbol *getPrivateReductionSymbol() { return privateReductionSymbol_; }
+ parser::CharBlock getOriginalSource() { return originalSource_; }
+ parser::Name getOriginalName() { return originalName_; }
+ parser::CharBlock getOriginalSubscript() {
+ return originalSubscriptCharBlock_;
+ }
+ Scope *getTempScope() { return tempScope_; }
+ bool isArrayElementReassigned() { return arrayElementReassigned_; }
+ bool isSectionTriplet() { return isSectionTriplet_; }
+ void arrayElementReassigned() { arrayElementReassigned_ = true; }
+ void setOriginalName(parser::Name &name) {
+ originalName_ = common::Clone(name);
+ }
+ void setOriginalSource(parser::CharBlock &source) {
+ originalSource_ = source;
+ }
+ void setOriginalSubscriptInt(parser::CharBlock &subscript) {
+ originalSubscriptCharBlock_ = subscript;
+ }
+ void setTempScope(Scope &scope) { tempScope_ = &scope; };
+ void setTempSymbol(Symbol *symbol) { replacementTempSymbol_ = symbol; }
+
+private:
+ Symbol *replacementTempSymbol_{nullptr};
+ Symbol *privateReductionSymbol_{nullptr};
+ Scope *tempScope_{nullptr};
+ parser::CharBlock originalSource_;
+ parser::Name originalName_;
+ parser::CharBlock originalSubscriptCharBlock_;
+ bool arrayElementReassigned_{false};
+ bool isSectionTriplet_{false};
+};
+
+class RewriteOmpReductionArrayElements {
+public:
+ explicit RewriteOmpReductionArrayElements(SemanticsContext &context)
+ : context_(context) {}
+ // Default action for a parse tree node is to visit children.
+ template <typename T> bool Pre(T &) { return true; }
+ template <typename T> void Post(T &) {}
+
+ void Post(parser::Block &block);
+ void Post(parser::Variable &var);
+ void Post(parser::Expr &expr);
+ void Post(parser::AssignmentStmt &assignmentStmt);
+ void Post(parser::PointerAssignmentStmt &ptrAssignmentStmt);
+ void rewriteReductionArrayElementToTemp(parser::Block &block);
+ bool isArrayElementRewritten() { return arrayElementReassigned_; }
+
+private:
+ bool isMatchingArrayElement(parser::Designator &existingDesignator);
+ template <typename T>
+ void processFunctionReference(
+ T &node, parser::CharBlock source, parser::FunctionReference &funcRef);
+ parser::Designator makeTempDesignator(parser::CharBlock source);
+ bool rewriteArrayElementToTemp(parser::Block::iterator &it,
+ parser::OpenMPLoopConstruct &ompLoop, parser::Block &block,
+ ReplacementTemp &temp);
+ bool identifyArrayElementReduced(
+ parser::Designator &designator, ReplacementTemp &temp);
+ void reassignTempValueToArrayElement(parser::ArrayElement &arrayElement);
+ void setCurrentTemp(ReplacementTemp *temp) { currentTemp_ = temp; }
+ void resetCurrentTemp() { currentTemp_ = nullptr; }
+
+ SemanticsContext &context_;
+ bool arrayElementReassigned_{false};
+ parser::Block::iterator reassignmentInsertionPoint_;
+ parser::Block *block_{nullptr};
+ ReplacementTemp *currentTemp_{nullptr};
+};
+
// Check that name has been resolved to a symbol
void RewriteMutator::Post(parser::Name &name) {
if (!name.symbol && errorOnUnresolvedName_) {
@@ -492,10 +572,454 @@ void RewriteMutator::Post(parser::WriteStmt &x) {
FixMisparsedUntaggedNamelistName(x);
}
+void ReplacementTemp::createTempSymbol(
+ SourceName &source, Scope &scope, SemanticsContext &context) {
+ replacementTempSymbol_ =
+ const_cast<semantics::Scope &>(originalName_.symbol->owner())
+ .FindSymbol(source);
+ replacementTempSymbol_->set_scope(
+ &const_cast<semantics::Scope &>(originalName_.symbol->owner()));
+ DeclTypeSpec *tempType = originalName_.symbol->GetUltimate().GetType();
+ replacementTempSymbol_->get<ObjectEntityDetails>().set_type(*tempType);
+ replacementTempSymbol_->flags().set(Symbol::Flag::CompilerCreated);
+}
+
+void ReplacementTemp::setOriginalSubscriptInt(
+ std::list<parser::SectionSubscript> &sectionSubscript) {
+ bool setSubscript{false};
+ for (parser::SectionSubscript &subscript : sectionSubscript) {
+ std::visit(llvm::makeVisitor(
+ [&](parser::IntExpr &intExpr) {
+ parser::Expr &expr = intExpr.thing.value();
+ std::visit(
+ llvm::makeVisitor(
+ [&](parser::LiteralConstant &literalContant) {
+ std::visit(llvm::makeVisitor(
+ [&](parser::IntLiteralConstant
+ &intLiteralConstant) {
+ originalSubscriptCharBlock_ =
+ std::get<parser::CharBlock>(
+ intLiteralConstant.t);
+ setSubscript = true;
+ },
+ [&](auto &) {}),
+ literalContant.u);
+ },
+ [&](auto &) {}),
+ expr.u);
+ },
+ [&](parser::SubscriptTriplet &triplet) {
+ isSectionTriplet_ = true;
+ setSubscript = true;
+ },
+ [&](auto &) {}),
+ subscript.u);
+ if (setSubscript) {
+ break;
+ }
+ }
+}
+
+void RewriteOmpReductionArrayElements::rewriteReductionArrayElementToTemp(
+ parser::Block &block) {
+ if (block.empty()) {
+ return;
+ }
+
+ for (auto it{block.begin()}; it != block.end(); ++it) {
+ std::visit(
+ llvm::makeVisitor(
+ [&](parser::ExecutableConstruct &execConstruct) {
+ std::visit(
+ llvm::makeVisitor(
+ [&](common::Indirection<parser::OpenMPConstruct>
+ &ompConstruct) {
+ std::visit(
+ llvm::makeVisitor(
+ [&](parser::OpenMPLoopConstruct &ompLoop) {
+ ReplacementTemp temp;
+ if (!rewriteArrayElementToTemp(
+ it, ompLoop, block, temp)) {
+ return;
+ }
+ auto &NestedConstruct = std::get<
+ std::optional<parser::NestedConstruct>>(
+ ompLoop.t);
+ if (!NestedConstruct.has_value()) {
+ return;
+ }
+ if (parser::DoConstruct *
+ doConst{std::get_if<parser::DoConstruct>(
+ &NestedConstruct.value())}) {
+ block_ = &block;
+ parser::Block &doBlock{
+ std::get<parser::Block>(doConst->t)};
+ parser::Walk(doBlock, *this);
+ // Reset the current temp value so future
+ // iterations use their own version.
+ resetCurrentTemp();
+ }
+ },
+ [&](auto &) {}),
+ ompConstruct.value().u);
+ },
+ [&](auto &) {}),
+ execConstruct.u);
+ },
+ [&](auto &) {}),
+ it->u);
+ }
+}
+
+bool RewriteOmpReductionArrayElements::isMatchingArrayElement(
+ parser::Designator &existingDesignator) {
+ bool matchesArrayElement{false};
+ std::list<parser::SectionSubscript> *subscripts{nullptr};
+
+ std::visit(llvm::makeVisitor(
+ [&](parser::DataRef &dataRef) {
+ std::visit(
+ llvm::makeVisitor(
+ [&](common::Indirection<parser::ArrayElement>
+ &arrayElement) {
+ subscripts = &arrayElement.value().subscripts;
+ std::visit(
+ llvm::makeVisitor(
+ [&](parser::Name &name) {
+ if (name.symbol->GetUltimate() ==
+ currentTemp_->getOriginalName()
+ .symbol->GetUltimate()) {
+ matchesArrayElement = true;
+ if (!currentTemp_
+ ->isArrayElementReassigned()) {
+ reassignTempValueToArrayElement(
+ arrayElement.value());
+ }
+ }
+ },
+ [](auto &) {}),
+ arrayElement.value().base.u);
+ },
+ [&](parser::Name &name) {
+ if (name.symbol->GetUltimate() ==
+ currentTemp_->getOriginalName()
+ .symbol->GetUltimate()) {
+ matchesArrayElement = true;
+ }
+ },
+ [](auto &) {}),
+ dataRef.u);
+ },
+ [&](auto &) {}),
+ existingDesignator.u);
+
+ if (subscripts) {
+ bool foundSubscript{false};
+ for (parser::SectionSubscript &subscript : *subscripts) {
+ matchesArrayElement = std::visit(
+ llvm::makeVisitor(
+ [&](parser::IntExpr &intExpr) -> bool {
+ parser::Expr &expr = intExpr.thing.value();
+ return std::visit(
+ llvm::makeVisitor(
+ [&](parser::LiteralConstant &literalContant) -> bool {
+ return std::visit(
+ llvm::makeVisitor(
+ [&](parser::IntLiteralConstant
+ &intLiteralConstant) -> bool {
+ foundSubscript = true;
+ assert(currentTemp_ != nullptr &&
+ "Value for ReplacementTemp should have "
+ "been found");
+ if (std::get<parser::CharBlock>(
+ intLiteralConstant.t) ==
+ currentTemp_->getOriginalSubscript()) {
+ return true;
+ }
+ return false;
+ },
+ [](auto &) -> bool { return false; }),
+ literalContant.u);
+ },
+ [](auto &) -> bool { return false; }),
+ expr.u);
+ },
+ [](auto &) -> bool { return false; }),
+ subscript.u);
+ if (foundSubscript) {
+ break;
+ }
+ }
+ }
+ return matchesArrayElement;
+}
+
+template <typename T>
+void RewriteOmpReductionArrayElements::processFunctionReference(
+ T &node, parser::CharBlock source, parser::FunctionReference &funcRef) {
+ auto &[procedureDesignator, ArgSpecList] = funcRef.v.t;
+ std::optional<parser::Designator> arrayElementDesignator =
+ std::visit(llvm::makeVisitor(
+ [&](parser::Name &functionReferenceName)
+ -> std::optional<parser::Designator> {
+ if (currentTemp_->getOriginalName().symbol ==
+ functionReferenceName.symbol) {
+ return funcRef.ConvertToArrayElementRef();
+ }
+ return std::nullopt;
+ },
+ [&](auto &) -> std::optional<parser::Designator> {
+ return std::nullopt;
+ }),
+ procedureDesignator.u);
+
+ if (arrayElementDesignator.has_value()) {
+ if (this->isMatchingArrayElement(arrayElementDesignator.value())) {
+ node = T{
+ common::Indirection<parser::Designator>{makeTempDesignator(source)}};
+ }
+ }
+}
+
+parser::Designator RewriteOmpReductionArrayElements::makeTempDesignator(
+ parser::CharBlock source) {
+ parser::Name tempVariableName{currentTemp_->getTempSymbol()->name()};
+ tempVariableName.symbol = currentTemp_->getTempSymbol();
+ parser::Designator tempDesignator{
+ parser::DataRef{std::move(tempVariableName)}};
+ tempDesignator.source = source;
+ return tempDesignator;
+}
+
+bool RewriteOmpReductionArrayElements::rewriteArrayElementToTemp(
+ parser::Block::iterator &it, parser::OpenMPLoopConstruct &ompLoop,
+ parser::Block &block, ReplacementTemp &temp) {
+ parser::OmpBeginLoopDirective &ompBeginLoop{
+ std::get<parser::OmpBeginLoopDirective>(ompLoop.t)};
+ std::list<parser::OmpClause> &clauseList{
+ std::get<std::optional<parser::OmpClauseList>>(ompBeginLoop.t)->v};
+ bool rewrittenArrayElement{false};
+
+ for (auto iter{clauseList.begin()}; iter != clauseList.end(); ++iter) {
+ rewrittenArrayElement = std::visit(
+ llvm::makeVisitor(
+ [&](parser::OmpClause::Reduction &clause) -> bool {
+ std::list<parser::OmpObject> &objectList =
+ std::get<parser::OmpObjectList>(clause.v.t).v;
+
+ bool rewritten{false};
+ for (auto object{objectList.begin()}; object != objectList.end();
+ ++object) {
+ rewritten = std::visit(
+ llvm::makeVisitor(
+ [&](parser::Designator &designator) -> bool {
+ if (!identifyArrayElementReduced(designator, temp)) {
+ return false;
+ }
+ if (temp.isSectionTriplet()) {
+ return false;
+ }
+
+ reassignmentInsertionPoint_ =
+ it != block.end() ? it : block.end();
+ std::string tempSourceString = "reduction_temp_" +
+ temp.getOriginalSource().ToString() + "(" +
+ temp.getOriginalSubscript().ToString() + ")";
+ SourceName source = context_.SaveTempName(
+ std::move(tempSourceString));
+ Scope &scope = const_cast<Scope &>(
+ temp.getOriginalName().symbol->owner());
+ if (Symbol * symbol{scope.FindSymbol(source)}) {
+ temp.setTempSymbol(symbol);
+ } else {
+ if (scope
+ .try_emplace(source, semantics::Attrs{},
+ semantics::ObjectEntityDetails{})
+ .second) {
+ temp.createTempSymbol(source, scope, context_);
+ } else {
+ common::die("Failed to create temp symbol for %s",
+ source.ToString().c_str());
+ }
+ }
+ setCurrentTemp(&temp);
+ temp.setTempScope(scope);
+
+ // Assign the value of the array element to the
+ // temporary variable
+ parser::Variable newVariable{
+ makeTempDesignator(temp.getOriginalSource())};
+ parser::Expr newExpr{
+ common::Indirection<parser::Designator>{
+ std::move(designator)}};
+ newExpr.source = temp.getOriginalSource();
+ std::tuple<parser::Variable, parser::Expr> newT{
+ std::move(newVariable), std::move(newExpr)};
+ parser::AssignmentStmt assignment{std::move(newT)};
+ parser::ExecutionPartConstruct
+ tempVariablePartConstruct{
+ parser::ExecutionPartConstruct{
+ parser::ExecutableConstruct{
+ parser::Statement<parser::ActionStmt>{
+ std::optional<parser::Label>{},
+ std::move(assignment)}}}};
+ block.insert(
+ it, std::move(tempVariablePartConstruct));
+ arrayElementReassigned_ = true;
+
+ designator =
+ makeTempDesignator(temp.getOriginalSource());
+ return true;
+ },
+ [&](const auto &) -> bool { return false; }),
+ object->u);
+ }
+ return rewritten;
+ },
+ [&](auto &) -> bool { return false; }),
+ iter->u);
+
+ if (rewrittenArrayElement) {
+ return rewrittenArrayElement;
+ }
+ }
+ return rewrittenArrayElement;
+}
+
+bool RewriteOmpReductionArrayElements::identifyArrayElementReduced(
+ parser::Designator &designator, ReplacementTemp &temp) {
+ return std::visit(
+ llvm::makeVisitor(
+ [&](parser::DataRef &dataRef) -> bool {
+ return std::visit(
+ llvm::makeVisitor(
+ [&](common::Indirection<parser::ArrayElement>
+ &arrayElement) {
+ std::visit(llvm::makeVisitor(
+ [&](parser::Name &name) -> void {
+ temp.setOriginalName(name);
+ temp.setOriginalSource(name.source);
+ },
+ [&](auto &) -> void {}),
+ arrayElement.value().base.u);
+ temp.setOriginalSubscriptInt(
+ arrayElement.value().subscripts);
+ return !temp.isSectionTriplet() ? true : false;
+ },
+ [&](auto &) -> bool { return false; }),
+ dataRef.u);
+ },
+ [&](auto &) -> bool { return false; }),
+ designator.u);
+}
+
+void RewriteOmpReductionArrayElements::reassignTempValueToArrayElement(
+ parser::ArrayElement &arrayElement) {
+ assert(block_ && "Need iterator to reassign");
+ parser::CharBlock originalSource = currentTemp_->getOriginalSource();
+ parser::DataRef reassignmentDataRef{std::move(arrayElement)};
+ common::Indirection<parser::Designator> arrayElementDesignator{
+ std::move(reassignmentDataRef)};
+ arrayElementDesignator.value().source = originalSource;
+ parser::Variable exisitingVar{std::move(arrayElementDesignator)};
+ std::get<c...
[truncated]
|
I forgot what was the difficulty here in reduction. I suspect that we faced a similar issue in OpenMP atomic (example source below) and it was handled during lowering.
In other words, is it not possible to get a reference to the array element and then use that as the value in the reduction clause's argument in MLIR?
Something like the following:
|
Some initial comments.
The large indentations in some places make the code harder to read. Could you extract some of the visitor lambdas into separate entities? e.g.
std::visit(
[&](...) {
[&] (...) {
}
}
)
to
auto foo = [&] (...) {
};
auto bar = [&] (...) {
foo
};
std::visit(...bar...);
Or something like that to reduce the amount of leading whitespace.
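A compilable version of that refactoring might look like the sketch below. It uses toy `std::variant` types rather than Flang's parser nodes, and a small `Overloaded` helper standing in for `llvm::makeVisitor`; all type and function names here are hypothetical.

```cpp
#include <cassert>
#include <string>
#include <variant>

// Toy stand-ins for parse tree nodes; these are not Flang's parser types.
struct IntLiteral { int value; };
struct RealLiteral { double value; };
struct Name { std::string text; };
using Literal = std::variant<IntLiteral, RealLiteral>;
using Expr = std::variant<Literal, Name>;

// Minimal overload helper, standing in for llvm::makeVisitor.
template <typename... Fs> struct Overloaded : Fs... {
  using Fs::operator()...;
};
template <typename... Fs> Overloaded(Fs...) -> Overloaded<Fs...>;

// Returns the integer literal's value, or -1 for any other kind of node.
int literalValueOrMinusOne(const Expr &expr) {
  // Innermost visitor first, bound to a name instead of nested inline.
  auto visitLiteral = Overloaded{
      [](const IntLiteral &lit) { return lit.value; },
      [](const auto &) { return -1; }};
  // The outer visitor refers to the inner one by name, so the indentation
  // stays at one level however deep the variant hierarchy goes.
  auto visitExpr = Overloaded{
      [&](const Literal &lit) { return std::visit(visitLiteral, lit); },
      [](const auto &) { return -1; }};
  return std::visit(visitExpr, expr);
}
```

Defining the visitors innermost-first keeps each one readable on its own, at the cost of reading the control flow bottom-up.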
I'm also not exactly sure what the issue is. When you (Jack) say that the whole array is lowered, what do you mean exactly? Could you maybe share a small example to illustrate the problem? |
Correct me if I'm reading the code wrong, but it seems like you do the reassignment back to the original array only when that element is used, but you do the replacement with a scalar in all cases. If so, that will produce incorrect code:
|
@kiranchandramohan From my understanding, addressing this in lowering is not possible. An array element is an expression, and the reference to the Symbol is used, which in this case would be the full array. This is OK for array sections, but if the user is only using a single array element, it does not make sense to reduce the full array, hence the rewrite to the temp so it can pick up a symbol which is just a single value. With this rewrite, the MLIR example you provided with a single integer being reduced is what ends up being generated, but it's not possible to do this using the array element itself because of the process I explained above.
@kparzysz An example of the lowering as it is would be, say for a small program
When the reduction appears in High Level FIR, it's generated as
As the user only needs one element of the array, it does not make sense to use the whole array in the reduction. This patch would then turn that into the following, with the assignment to and from the temp in the relevant places:
With the example you have given in the subroutine |
When generating the body of the construct containing the reduction clause, we need to ensure that references to the reduced value use the private reduction variable and not some host variable. The way this works for reductions (and I think privatisation) is that we create a host-associated symbol for the variable being reduced, and that shadows the original shared copy, ensuring that references to that symbol inside the body are lowered to the reduction block argument instead. This works well when whole symbols are reduced, but not when what is reduced is an expression. We don't have a convenient way to instruct lowering to replace all instances of a parse tree expression with a new value (the reduction block argument). I couldn't think of a way to add that without changing a lot of non-OpenMP code. This workaround is to rewrite something like
into
This is a first step towards supporting reductions on components of derived types (which will be handled with a temporary in a similar way). |
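The shadowing mechanism described in this comment can be pictured with a toy scope chain. The sketch below is an illustration only: `ToyScope` is a hypothetical class, not Flang's `semantics::Scope`, and the string values merely label where a reference would end up pointing.

```cpp
#include <cassert>
#include <map>
#include <optional>
#include <string>

// Toy scope chain; not Flang's semantics::Scope.
class ToyScope {
public:
  explicit ToyScope(const ToyScope *parent = nullptr) : parent_(parent) {}

  void bind(const std::string &name, const std::string &storage) {
    bindings_[name] = storage;
  }

  // The innermost binding wins, which is what lets a construct-local
  // symbol shadow the host variable of the same name.
  std::optional<std::string> resolve(const std::string &name) const {
    for (const ToyScope *scope = this; scope; scope = scope->parent_) {
      auto it = scope->bindings_.find(name);
      if (it != scope->bindings_.end()) {
        return it->second;
      }
    }
    return std::nullopt;
  }

private:
  const ToyScope *parent_;
  std::map<std::string, std::string> bindings_;
};
```

Binding `s` in a body scope redirects every lookup of `s` made from that scope, but a lookup of `x(2)` finds nothing because an array element is an expression rather than a name. Introducing a named temporary gives the reduction something that can be shadowed.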
I'm ok with this approach, but I think we need to develop a solution in lowering in the longer term. I'm concerned that doing something like the example above could potentially introduce subtle race conditions
I think we can probably get away with this for now. |
Collate all std::visit's into lambda functions to make them easier to read.
@kparzysz thanks for the initial review. I have also gone through and extracted the std::visit
parts of the code into lambda functions so it is easier to follow and review as you suggested.
I haven't looked at this in detail. FYI: an expression to mlir::value map was added in #69944 for simplifying codegen for atomic. |
Thanks, it looks a lot clearer now.
This code will only handle one reduction variable per reduction clause. For example, given
subroutine f00(x, y)
integer :: x(:), y(:)
integer :: i
!$omp parallel do reduction(+: x(2), y(3))
do i = 1, 200
x(2) = x(2) + i
y(3) = 6 - y(3)
end do
!$omp end parallel do
end
it generates
SUBROUTINE f00 (x, y)
INTEGER x(:), y(:)
INTEGER i
reduction_temp_x(2)=x(2_8)
reduction_temp_y(3)=y(3_8)
!$OMP PARALLEL DO REDUCTION(+: reduction_temp_x(2),reduction_temp_y(3))
DO i=1_4,200_4
x(2_8)=x(2_8)+i ! unmodified
reduction_temp_y(3)=6_4-reduction_temp_y(3)
END DO
!$OMP END PARALLEL DO
y(3_8)=reduction_temp_y(3)
END SUBROUTINE
leaving x(2) = x(2) + i
unmodified. Is that a known limitation?
Also, if you modify the testcase a bit:
subroutine f00(x, y, j)
integer :: x(:), y(:)
integer :: i, j
!$omp parallel do reduction(+: x(2), y(j))
do i = 1, 200
x(2) = x(2) + i
y(j) = 6 - y(j)
end do
!$omp end parallel do
end
you get incorrect output:
SUBROUTINE f00 (x, y, j)
INTEGER x(:), y(:)
INTEGER i, j
reduction_temp_x(2)=x(2_8)
reduction_temp_y(2)=y(int(j,kind=8))
!$OMP PARALLEL DO REDUCTION(+: reduction_temp_x(2),reduction_temp_y(2))
DO i=1_4,200_4
reduction_temp_y(2)=reduction_temp_y(2)+i
reduction_temp_y(2)=6_4-y(int(j,kind=8))
END DO
!$OMP END PARALLEL DO
y(int(j,kind=8))=reduction_temp_y(2)
END SUBROUTINE
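One way to avoid both problems would be to track a separate temporary per reduced element, keyed on the array and the subscript together, and to decline the rewrite when the subscript is not an integer literal. The sketch below is a hypothetical illustration of that bookkeeping, not part of the patch: `ReductionTempTable` and its members are invented names, and only the `reduction_temp_<array>(<subscript>)` naming follows the diff.

```cpp
#include <cassert>
#include <cctype>
#include <map>
#include <optional>
#include <string>
#include <utility>

// Hypothetical bookkeeping for several reduction temporaries at once.
class ReductionTempTable {
public:
  // Registers one reduction object and returns its temporary's name.
  // Declines (nullopt) unless the subscript is an integer literal, so
  // y(j) is left alone instead of being matched against the wrong key.
  std::optional<std::string> add(
      const std::string &array, const std::string &subscript) {
    if (!isIntegerLiteral(subscript)) {
      return std::nullopt;
    }
    std::string temp = "reduction_temp_" + array + "(" + subscript + ")";
    temps_[std::make_pair(array, subscript)] = temp;
    return temp;
  }

  // Looks up the temporary for a use inside the loop body. Both the
  // array and the subscript must match a registered reduction object.
  std::optional<std::string> find(
      const std::string &array, const std::string &subscript) const {
    auto it = temps_.find(std::make_pair(array, subscript));
    if (it == temps_.end()) {
      return std::nullopt;
    }
    return it->second;
  }

private:
  static bool isIntegerLiteral(const std::string &text) {
    if (text.empty()) {
      return false;
    }
    for (unsigned char c : text) {
      if (!std::isdigit(c)) {
        return false;
      }
    }
    return true;
  }

  std::map<std::pair<std::string, std::string>, std::string> temps_;
};
```

With a table like this, `x(2)` and `y(3)` each resolve to their own temporary, and a use of `y(2)` matches neither.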
}
}
};
auto visitArratElement = [&](parser::ArrayElement &arrayElement) {
Typo: arrat
parser::OmpBeginLoopDirective &ompBeginLoop{
    std::get<parser::OmpBeginLoopDirective>(ompLoop.t)};
std::list<parser::OmpClause> &clauseList{
    std::get<std::optional<parser::OmpClauseList>>(ompBeginLoop.t)->v};
In case you didn't know, OpenMPLoopConstruct has a function "BeginDir" that returns a reference to the contained "OmpBeginLoopDirective". The OmpBeginLoopDirective inherits from OmpDirectiveSpecification that has member function Clauses() that returns the list of clauses.
You could do
auto &clauseList{const_cast<std::list<parser::OmpClause> &>(ompLoop.BeginDir().Clauses().v)};
or
auto &clauseList{const_cast<parser::OmpClauseList &>(ompLoop.BeginDir().Clauses())};
and use clauseList.v.
std::list<parser::OmpObject> &objectList =
    std::get<parser::OmpObjectList>(clause.v.t).v;
FYI, there is a function in flang/include/flang/Parser/openmp-utils.h
that gets OmpObjectList from any clause (or returns nullptr if the clause doesn't have one):
const OmpObjectList *GetOmpObjectList(const OmpClause &clause);