Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions flang/include/flang/Lower/Support/ReductionProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,13 @@ namespace omp {

class ReductionProcessor {
public:
using GenInitValueCBTy =
std::function<mlir::Value(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Type type, mlir::Value ompOrig)>;
using GenCombinerCBTy = std::function<void(
fir::FirOpBuilder &builder, mlir::Location loc, mlir::Type type,
mlir::Value op1, mlir::Value op2, bool isByRef)>;

// TODO: Move this enumeration to the OpenMP dialect
enum ReductionIdentifier {
ID,
Expand All @@ -58,6 +65,9 @@ class ReductionProcessor {
IEOR
};

static bool doReductionByRef(mlir::Type reductionType);
static bool doReductionByRef(mlir::Value reductionVar);

static ReductionIdentifier
getReductionType(const omp::clause::ProcedureDesignator &pd);

Expand Down Expand Up @@ -109,6 +119,14 @@ class ReductionProcessor {
ReductionIdentifier redId,
mlir::Type type, mlir::Value op1,
mlir::Value op2);
/// Creates an OpenMP reduction declaration and inserts it into the provided
/// symbol table. The init and combiner regions are generated by the callback
/// functions genCombinerCB and genInitValueCB.
template <typename DeclareRedType>
static DeclareRedType createDeclareReductionHelper(
AbstractConverter &converter, llvm::StringRef reductionOpName,
mlir::Type type, mlir::Location loc, bool isByRef,
GenCombinerCBTy genCombinerCB, GenInitValueCBTy genInitValueCB);

/// Creates an OpenMP reduction declaration and inserts it into the provided
/// symbol table. The declaration has a constant initializer with the neutral
Expand Down
60 changes: 60 additions & 0 deletions flang/lib/Lower/OpenMP/ClauseProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "ClauseProcessor.h"
#include "Utils.h"

#include "flang/Lower/ConvertCall.h"
#include "flang/Lower/ConvertExprToHLFIR.h"
#include "flang/Lower/OpenMP/Clauses.h"
#include "flang/Lower/PFTBuilder.h"
Expand Down Expand Up @@ -402,6 +403,65 @@ bool ClauseProcessor::processInclusive(
return false;
}

bool ClauseProcessor::processInitializer(
lower::SymMap &symMap, const parser::OmpClause::Initializer &inp,
ReductionProcessor::GenInitValueCBTy &genInitValueCB) const {
if (auto *clause = findUniqueClause<omp::clause::Initializer>()) {
genInitValueCB = [&, clause](fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Type type, mlir::Value ompOrig) {
lower::SymMapScope scope(symMap);
const parser::OmpInitializerExpression &iexpr = inp.v.v;
const parser::OmpStylizedInstance &styleInstance = iexpr.v.front();
const std::list<parser::OmpStylizedDeclaration> &declList =
std::get<std::list<parser::OmpStylizedDeclaration>>(styleInstance.t);
mlir::Value ompPrivVar;
for (const parser::OmpStylizedDeclaration &decl : declList) {
auto &name = std::get<parser::ObjectName>(decl.var.t);
assert(name.symbol && "Name does not have a symbol");
mlir::Value addr = builder.createTemporary(loc, ompOrig.getType());
fir::StoreOp::create(builder, loc, ompOrig, addr);
fir::FortranVariableFlagsEnum extraFlags = {};
fir::FortranVariableFlagsAttr attributes =
Fortran::lower::translateSymbolAttributes(builder.getContext(),
*name.symbol, extraFlags);
auto declareOp = hlfir::DeclareOp::create(
builder, loc, addr, name.ToString(), nullptr, {}, nullptr, nullptr,
0, attributes);
if (name.ToString() == "omp_priv")
ompPrivVar = declareOp.getResult(0);
symMap.addVariableDefinition(*name.symbol, declareOp);
}
// Lower the expression/function call
lower::StatementContext stmtCtx;
mlir::Value result = common::visit(
common::visitors{
[&](const evaluate::ProcedureRef &procRef) -> mlir::Value {
convertCallToHLFIR(loc, converter, procRef, std::nullopt,
symMap, stmtCtx);
auto privVal = fir::LoadOp::create(builder, loc, ompPrivVar);
return privVal;
},
[&](const auto &expr) -> mlir::Value {
mlir::Value exprResult = fir::getBase(convertExprToValue(
loc, converter, clause->v, symMap, stmtCtx));
// Conversion can either give a value or a refrence to a value,
// we need to return the reduction type, so an optional load may
// be generated.
if (auto refType = llvm::dyn_cast<fir::ReferenceType>(
exprResult.getType()))
if (ompPrivVar.getType() == refType)
exprResult = fir::LoadOp::create(builder, loc, exprResult);
return exprResult;
}},
clause->v.u);
stmtCtx.finalizeAndPop();
return result;
};
return true;
}
return false;
}

bool ClauseProcessor::processMergeable(
mlir::omp::MergeableClauseOps &result) const {
return markClauseOccurrence<omp::clause::Mergeable>(result.mergeable);
Expand Down
4 changes: 4 additions & 0 deletions flang/lib/Lower/OpenMP/ClauseProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "flang/Lower/Bridge.h"
#include "flang/Lower/DirectivesCommon.h"
#include "flang/Lower/OpenMP/Clauses.h"
#include "flang/Lower/Support/ReductionProcessor.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Parser/dump-parse-tree.h"
#include "flang/Parser/parse-tree.h"
Expand Down Expand Up @@ -88,6 +89,9 @@ class ClauseProcessor {
bool processHint(mlir::omp::HintClauseOps &result) const;
bool processInclusive(mlir::Location currentLocation,
mlir::omp::InclusiveClauseOps &result) const;
bool processInitializer(
lower::SymMap &symMap, const parser::OmpClause::Initializer &inp,
ReductionProcessor::GenInitValueCBTy &genInitValueCB) const;
bool processMergeable(mlir::omp::MergeableClauseOps &result) const;
bool processNogroup(mlir::omp::NogroupClauseOps &result) const;
bool processNowait(mlir::omp::NowaitClauseOps &result) const;
Expand Down
17 changes: 16 additions & 1 deletion flang/lib/Lower/OpenMP/Clauses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -981,7 +981,22 @@ Init make(const parser::OmpClause::Init &inp,

Initializer make(const parser::OmpClause::Initializer &inp,
semantics::SemanticsContext &semaCtx) {
llvm_unreachable("Empty: initializer");
const parser::OmpInitializerExpression &iexpr = inp.v.v;
const parser::OmpStylizedInstance &styleInstance = iexpr.v.front();
const parser::OmpStylizedInstance::Instance &instance =
std::get<parser::OmpStylizedInstance::Instance>(styleInstance.t);
if (const auto *as = std::get_if<parser::AssignmentStmt>(&instance.u)) {
auto &expr = std::get<parser::Expr>(as->t);
return Initializer{makeExpr(expr, semaCtx)};
} else if (const auto *call = std::get_if<parser::CallStmt>(&instance.u)) {
if (call->typedCall) {
const auto &procRef = *call->typedCall;
semantics::SomeExpr evalProcRef{procRef};
return Initializer{evalProcRef};
}
}

llvm_unreachable("Unexpected initializer");
}

InReduction make(const parser::OmpClause::InReduction &inp,
Expand Down
133 changes: 130 additions & 3 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,15 @@
#include "Decomposer.h"
#include "Utils.h"
#include "flang/Common/idioms.h"
#include "flang/Evaluate/type.h"
#include "flang/Lower/Bridge.h"
#include "flang/Lower/ConvertExpr.h"
#include "flang/Lower/ConvertExprToHLFIR.h"
#include "flang/Lower/ConvertVariable.h"
#include "flang/Lower/DirectivesCommon.h"
#include "flang/Lower/OpenMP/Clauses.h"
#include "flang/Lower/StatementContext.h"
#include "flang/Lower/Support/ReductionProcessor.h"
#include "flang/Lower/SymbolMap.h"
#include "flang/Optimizer/Builder/BoxValue.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
Expand Down Expand Up @@ -2847,7 +2850,6 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
// TODO: Add private syms and vars.
args.reduction.syms = reductionSyms;
args.reduction.vars = clauseOps.reductionVars;

return genOpWithBody<mlir::omp::TeamsOp>(
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
llvm::omp::Directive::OMPD_teams)
Expand Down Expand Up @@ -3563,12 +3565,137 @@ genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
TODO(converter.getCurrentLocation(), "OmpDeclareVariantDirective");
}

static bool
processReductionCombiner(lower::AbstractConverter &converter,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we return GenCombinerCBTy instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, done!

lower::SymMap &symTable,
semantics::SemanticsContext &semaCtx,
const parser::OmpReductionSpecifier &specifier,
ReductionProcessor::GenCombinerCBTy &genCombinerCB) {
const auto &combinerExpression =
std::get<std::optional<parser::OmpCombinerExpression>>(specifier.t)
.value();
const parser::OmpStylizedInstance &combinerInstance =
combinerExpression.v.front();
const parser::OmpStylizedInstance::Instance &instance =
std::get<parser::OmpStylizedInstance::Instance>(combinerInstance.t);
if (const auto *as = std::get_if<parser::AssignmentStmt>(&instance.u)) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How is the caller expected to behave in case there is no AssignmentStmt? Should we signal this as an error somehow? Or assert?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be a TODO in the case where a subroutine is called. So I did an early return check with TODO.

auto &expr = std::get<parser::Expr>(as->t);
genCombinerCB = [&](fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Type type, mlir::Value lhs, mlir::Value rhs,
bool isByRef) {
const auto &evalExpr = makeExpr(expr, semaCtx);
lower::SymMapScope scope(symTable);
const std::list<parser::OmpStylizedDeclaration> &declList =
std::get<std::list<parser::OmpStylizedDeclaration>>(
combinerInstance.t);
for (const parser::OmpStylizedDeclaration &decl : declList) {
auto &name = std::get<parser::ObjectName>(decl.var.t);
mlir::Value addr = lhs;
mlir::Type type = lhs.getType();
bool isRhs = name.ToString() == std::string("omp_in");
if (isRhs) {
addr = rhs;
type = rhs.getType();
}

assert(name.symbol && "Reduction object name does not have a symbol");
if (!fir::conformsWithPassByRef(type)) {
addr = builder.createTemporary(loc, type);
fir::StoreOp::create(builder, loc, isRhs ? rhs : lhs, addr);
}
fir::FortranVariableFlagsEnum extraFlags = {};
fir::FortranVariableFlagsAttr attributes =
Fortran::lower::translateSymbolAttributes(builder.getContext(),
*name.symbol, extraFlags);
auto declareOp = hlfir::DeclareOp::create(
builder, loc, addr, name.ToString(), nullptr, {}, nullptr, nullptr,
0, attributes);
symTable.addVariableDefinition(*name.symbol, declareOp);
}

lower::StatementContext stmtCtx;
mlir::Value result = fir::getBase(
convertExprToValue(loc, converter, evalExpr, symTable, stmtCtx));
if (auto refType = llvm::dyn_cast<fir::ReferenceType>(result.getType()))
if (lhs.getType() == refType.getElementType())
result = fir::LoadOp::create(builder, loc, result);
stmtCtx.finalizeAndPop();
if (isByRef) {
fir::StoreOp::create(builder, loc, result, lhs);
mlir::omp::YieldOp::create(builder, loc, lhs);
} else {
mlir::omp::YieldOp::create(builder, loc, result);
}

return result;
};
}
return true;
}

// Getting the type from a symbol compared to a DeclSpec is simpler since we do
// not need to consider derived vs intrinsic types. Semantics is guaranteed to
// generate these symbols.
static mlir::Type
getReductionType(lower::AbstractConverter &converter,
const parser::OmpReductionSpecifier &specifier) {
const auto &combinerExpression =
std::get<std::optional<parser::OmpCombinerExpression>>(specifier.t)
.value();
const parser::OmpStylizedInstance &combinerInstance =
combinerExpression.v.front();
const std::list<parser::OmpStylizedDeclaration> &declList =
std::get<std::list<parser::OmpStylizedDeclaration>>(combinerInstance.t);
const parser::OmpStylizedDeclaration &decl = declList.front();
const auto &name = std::get<parser::ObjectName>(decl.var.t);
const auto &symbol = semantics::SymbolRef(*name.symbol);
mlir::Type reductionType = converter.genType(symbol);
return reductionType;
}

static void genOMP(
lower::AbstractConverter &converter, lower::SymMap &symTable,
semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
const parser::OpenMPDeclareReductionConstruct &declareReductionConstruct) {
if (!semaCtx.langOptions().OpenMPSimd)
TODO(converter.getCurrentLocation(), "OpenMPDeclareReductionConstruct");
if (!semaCtx.langOptions().OpenMPSimd) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I think an early return would be more readable here specially given the nested if conditions.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, early return is better. Fixed.

const parser::OmpArgumentList &args{
declareReductionConstruct.v.Arguments()};
const parser::OmpArgument &arg{args.v.front()};
const auto &specifier = std::get<parser::OmpReductionSpecifier>(arg.u);

if (std::get<parser::OmpTypeNameList>(specifier.t).v.size() > 1)
TODO(converter.getCurrentLocation(),
"multiple types in declare target is not yet supported");

mlir::Type reductionType = getReductionType(converter, specifier);
ReductionProcessor::GenCombinerCBTy genCombinerCB;
processReductionCombiner(converter, symTable, semaCtx, specifier,
genCombinerCB);
const parser::OmpClauseList &initializer =
declareReductionConstruct.v.Clauses();
if (initializer.v.size() > 0) {
List<Clause> clauses = makeClauses(initializer, semaCtx);
ReductionProcessor::GenInitValueCBTy genInitValueCB;
ClauseProcessor cp(converter, semaCtx, clauses);
const parser::OmpClause::Initializer &iclause{
std::get<parser::OmpClause::Initializer>(initializer.v.front().u)};
cp.processInitializer(symTable, iclause, genInitValueCB);
const auto &identifier =
std::get<parser::OmpReductionIdentifier>(specifier.t);
const auto &designator =
std::get<parser::ProcedureDesignator>(identifier.u);
const auto &reductionName = std::get<parser::Name>(designator.u);
bool isByRef = ReductionProcessor::doReductionByRef(reductionType);
ReductionProcessor::createDeclareReductionHelper<
mlir::omp::DeclareReductionOp>(
converter, reductionName.ToString(), reductionType,
converter.getCurrentLocation(), isByRef, genCombinerCB,
genInitValueCB);
} else {
TODO(converter.getCurrentLocation(),
"declare target without an initializer clause is not yet supported");
}
}
}

static void
Expand Down
Loading
Loading