Skip to content
30 changes: 14 additions & 16 deletions flang/include/flang/Lower/OpenACC.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ class StringRef;
} // namespace llvm

namespace mlir {
namespace func {
class FuncOp;
}
class Location;
class Type;
class ModuleOp;
Expand All @@ -34,6 +37,10 @@ class FirOpBuilder;
}

namespace Fortran {
namespace evaluate {
class ProcedureDesignator;
} // namespace evaluate

namespace parser {
struct AccClauseList;
struct OpenACCConstruct;
Expand All @@ -42,6 +49,7 @@ struct OpenACCRoutineConstruct;
} // namespace parser

namespace semantics {
class OpenACCRoutineInfo;
class SemanticsContext;
class Symbol;
} // namespace semantics
Expand All @@ -55,9 +63,6 @@ namespace pft {
struct Evaluation;
} // namespace pft

using AccRoutineInfoMappingList =
llvm::SmallVector<std::pair<std::string, mlir::SymbolRefAttr>>;

static constexpr llvm::StringRef declarePostAllocSuffix =
"_acc_declare_update_desc_post_alloc";
static constexpr llvm::StringRef declarePreDeallocSuffix =
Expand All @@ -71,19 +76,12 @@ mlir::Value genOpenACCConstruct(AbstractConverter &,
Fortran::semantics::SemanticsContext &,
pft::Evaluation &,
const parser::OpenACCConstruct &);
void genOpenACCDeclarativeConstruct(AbstractConverter &,
Fortran::semantics::SemanticsContext &,
StatementContext &,
const parser::OpenACCDeclarativeConstruct &,
AccRoutineInfoMappingList &);
void genOpenACCRoutineConstruct(AbstractConverter &,
Fortran::semantics::SemanticsContext &,
mlir::ModuleOp,
const parser::OpenACCRoutineConstruct &,
AccRoutineInfoMappingList &);

void finalizeOpenACCRoutineAttachment(mlir::ModuleOp,
AccRoutineInfoMappingList &);
void genOpenACCDeclarativeConstruct(
AbstractConverter &, Fortran::semantics::SemanticsContext &,
StatementContext &, const parser::OpenACCDeclarativeConstruct &);
void genOpenACCRoutineConstruct(
AbstractConverter &, mlir::ModuleOp, mlir::func::FuncOp,
const std::vector<Fortran::semantics::OpenACCRoutineInfo> &);

/// Get a acc.private.recipe op for the given type or create it if it does not
/// exist yet.
Expand Down
41 changes: 32 additions & 9 deletions flang/include/flang/Semantics/symbol.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <list>
#include <optional>
#include <set>
#include <variant>
#include <vector>

namespace llvm {
Expand Down Expand Up @@ -127,6 +128,8 @@ class WithBindName {
// Device type specific OpenACC routine information
class OpenACCRoutineDeviceTypeInfo {
public:
OpenACCRoutineDeviceTypeInfo(Fortran::common::OpenACCDeviceType dType)
: deviceType_{dType} {}
bool isSeq() const { return isSeq_; }
void set_isSeq(bool value = true) { isSeq_ = value; }
bool isVector() const { return isVector_; }
Expand All @@ -137,22 +140,28 @@ class OpenACCRoutineDeviceTypeInfo {
void set_isGang(bool value = true) { isGang_ = value; }
unsigned gangDim() const { return gangDim_; }
void set_gangDim(unsigned value) { gangDim_ = value; }
const std::string *bindName() const {
return bindName_ ? &*bindName_ : nullptr;
const std::variant<std::string, SymbolRef> *bindName() const {
return bindName_.has_value() ? &*bindName_ : nullptr;
}
void set_bindName(std::string &&name) { bindName_ = std::move(name); }
void set_dType(Fortran::common::OpenACCDeviceType dType) {
deviceType_ = dType;
const std::optional<std::variant<std::string, SymbolRef>> &
bindNameOpt() const {
return bindName_;
}
void set_bindName(std::string &&name) { bindName_.emplace(std::move(name)); }
void set_bindName(SymbolRef symbol) { bindName_.emplace(symbol); }

Fortran::common::OpenACCDeviceType dType() const { return deviceType_; }

friend llvm::raw_ostream &operator<<(
llvm::raw_ostream &, const OpenACCRoutineDeviceTypeInfo &);

private:
bool isSeq_{false};
bool isVector_{false};
bool isWorker_{false};
bool isGang_{false};
unsigned gangDim_{0};
std::optional<std::string> bindName_;
std::optional<std::variant<std::string, SymbolRef>> bindName_;
Fortran::common::OpenACCDeviceType deviceType_{
Fortran::common::OpenACCDeviceType::None};
};
Expand All @@ -162,15 +171,29 @@ class OpenACCRoutineDeviceTypeInfo {
// in as objects in the OpenACCRoutineDeviceTypeInfo list.
class OpenACCRoutineInfo : public OpenACCRoutineDeviceTypeInfo {
public:
OpenACCRoutineInfo()
: OpenACCRoutineDeviceTypeInfo(Fortran::common::OpenACCDeviceType::None) {
}
bool isNohost() const { return isNohost_; }
void set_isNohost(bool value = true) { isNohost_ = value; }
std::list<OpenACCRoutineDeviceTypeInfo> &deviceTypeInfos() {
const std::list<OpenACCRoutineDeviceTypeInfo> &deviceTypeInfos() const {
return deviceTypeInfos_;
}
void add_deviceTypeInfo(OpenACCRoutineDeviceTypeInfo &info) {
deviceTypeInfos_.push_back(info);

OpenACCRoutineDeviceTypeInfo &add_deviceTypeInfo(
Fortran::common::OpenACCDeviceType type) {
return add_deviceTypeInfo(OpenACCRoutineDeviceTypeInfo(type));
}

OpenACCRoutineDeviceTypeInfo &add_deviceTypeInfo(
OpenACCRoutineDeviceTypeInfo &&info) {
deviceTypeInfos_.push_back(std::move(info));
return deviceTypeInfos_.back();
}

friend llvm::raw_ostream &operator<<(
llvm::raw_ostream &, const OpenACCRoutineInfo &);

private:
std::list<OpenACCRoutineDeviceTypeInfo> deviceTypeInfos_;
bool isNohost_{false};
Expand Down
54 changes: 24 additions & 30 deletions flang/lib/Lower/Bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -403,18 +403,21 @@ class FirConverter : public Fortran::lower::AbstractConverter {
[&](Fortran::lower::pft::FunctionLikeUnit &f) {
if (f.isMainProgram())
hasMainProgram = true;
declareFunction(f);
createGlobalOutsideOfFunctionLowering(
[&]() { declareFunction(f); });
if (!globalOmpRequiresSymbol)
globalOmpRequiresSymbol = f.getScope().symbol();
},
[&](Fortran::lower::pft::ModuleLikeUnit &m) {
lowerModuleDeclScope(m);
for (Fortran::lower::pft::ContainedUnit &unit :
m.containedUnitList)
if (auto *f =
std::get_if<Fortran::lower::pft::FunctionLikeUnit>(
&unit))
declareFunction(*f);
createGlobalOutsideOfFunctionLowering([&]() {
for (Fortran::lower::pft::ContainedUnit &unit :
m.containedUnitList)
if (auto *f =
std::get_if<Fortran::lower::pft::FunctionLikeUnit>(
&unit))
declareFunction(*f);
});
},
[&](Fortran::lower::pft::BlockDataUnit &b) {
if (!globalOmpRequiresSymbol)
Expand All @@ -438,14 +441,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
[&](Fortran::lower::pft::ModuleLikeUnit &m) { lowerMod(m); },
[&](Fortran::lower::pft::BlockDataUnit &b) {},
[&](Fortran::lower::pft::CompilerDirectiveUnit &d) {},
[&](Fortran::lower::pft::OpenACCDirectiveUnit &d) {
builder = new fir::FirOpBuilder(
bridge.getModule(), bridge.getKindMap(), &mlirSymbolTable);
Fortran::lower::genOpenACCRoutineConstruct(
*this, bridge.getSemanticsContext(), bridge.getModule(),
d.routine, accRoutineInfos);
builder = nullptr;
},
[&](Fortran::lower::pft::OpenACCDirectiveUnit &d) {},
},
u);
}
Expand All @@ -465,12 +461,12 @@ class FirConverter : public Fortran::lower::AbstractConverter {
Fortran::common::LanguageFeature::CUDA));
});

finalizeOpenACCLowering();
finalizeOpenMPLowering(globalOmpRequiresSymbol);
}

/// Declare a function.
void declareFunction(Fortran::lower::pft::FunctionLikeUnit &funit) {
CHECK(builder && "declareFunction called with uninitialized builder");
setCurrentPosition(funit.getStartingSourceLoc());
for (int entryIndex = 0, last = funit.entryPointList.size();
entryIndex < last; ++entryIndex) {
Expand Down Expand Up @@ -1035,7 +1031,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
return bridge.getSemanticsContext().FindScope(currentPosition);
}

fir::FirOpBuilder &getFirOpBuilder() override final { return *builder; }
fir::FirOpBuilder &getFirOpBuilder() override final {
CHECK(builder && "builder is not set before calling getFirOpBuilder");
return *builder;
}

mlir::ModuleOp getModuleOp() override final { return bridge.getModule(); }

Expand Down Expand Up @@ -3018,8 +3017,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {

void genFIR(const Fortran::parser::OpenACCDeclarativeConstruct &accDecl) {
genOpenACCDeclarativeConstruct(*this, bridge.getSemanticsContext(),
bridge.openAccCtx(), accDecl,
accRoutineInfos);
bridge.openAccCtx(), accDecl);
for (Fortran::lower::pft::Evaluation &e : getEval().getNestedEvaluations())
genFIR(e);
}
Expand Down Expand Up @@ -5612,6 +5610,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
LLVM_DEBUG(llvm::dbgs() << "\n[bridge - startNewFunction]";
if (auto *sym = scope.symbol()) llvm::dbgs() << " " << *sym;
llvm::dbgs() << "\n");
// Setting the builder is not necessary here, because callee
// always looks up the FuncOp from the module. If there was a function that
// was not declared yet, this call to callee will cause an assertion
// failure.
Fortran::lower::CalleeInterface callee(funit, *this);
mlir::func::FuncOp func = callee.addEntryBlockAndMapArguments();
builder =
Expand Down Expand Up @@ -5881,7 +5883,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
/// Helper to generate GlobalOps when the builder is not positioned in any
/// region block. This is required because the FirOpBuilder assumes it is
/// always positioned inside a region block when creating globals, the easiest
/// way comply is to create a dummy function and to throw it afterwards.
/// way to comply is to create a dummy function and to throw it away
/// afterwards.
void createGlobalOutsideOfFunctionLowering(
const std::function<void()> &createGlobals) {
// FIXME: get rid of the bogus function context and instantiate the
Expand All @@ -5894,6 +5897,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
mlir::FunctionType::get(context, std::nullopt, std::nullopt),
symbolTable);
func.addEntryBlock();
CHECK(!builder && "Expected builder to be uninitialized");
builder = new fir::FirOpBuilder(func, bridge.getKindMap(), symbolTable);
assert(builder && "FirOpBuilder did not instantiate");
builder->setFastMathFlags(bridge.getLoweringOptions().getMathOptions());
Expand Down Expand Up @@ -6323,13 +6327,6 @@ class FirConverter : public Fortran::lower::AbstractConverter {
expr.u);
}

/// Performing OpenACC lowering action that were deferred to the end of
/// lowering.
void finalizeOpenACCLowering() {
Fortran::lower::finalizeOpenACCRoutineAttachment(getModuleOp(),
accRoutineInfos);
}

/// Performing OpenMP lowering actions that were deferred to the end of
/// lowering.
void finalizeOpenMPLowering(
Expand Down Expand Up @@ -6421,9 +6418,6 @@ class FirConverter : public Fortran::lower::AbstractConverter {
/// A counter for uniquing names in `literalNamesMap`.
std::uint64_t uniqueLitId = 0;

/// Deferred OpenACC routine attachment.
Fortran::lower::AccRoutineInfoMappingList accRoutineInfos;

/// Whether an OpenMP target region or declare target function/subroutine
/// intended for device offloading has been detected
bool ompDeviceCodeFound = false;
Expand Down
12 changes: 12 additions & 0 deletions flang/lib/Lower/CallInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "flang/Evaluate/fold.h"
#include "flang/Lower/Bridge.h"
#include "flang/Lower/Mangler.h"
#include "flang/Lower/OpenACC.h"
#include "flang/Lower/PFTBuilder.h"
#include "flang/Lower/StatementContext.h"
#include "flang/Lower/Support/Utils.h"
Expand Down Expand Up @@ -715,6 +716,17 @@ void Fortran::lower::CallInterface<T>::declare() {
func.setArgAttrs(placeHolder.index(), placeHolder.value().attributes);

setCUDAAttributes(func, side().getProcedureSymbol(), characteristic);

if (const Fortran::semantics::Symbol *sym = side().getProcedureSymbol()) {
if (const auto &info{
sym->GetUltimate()
.detailsIf<Fortran::semantics::SubprogramDetails>()}) {
if (!info->openACCRoutineInfos().empty()) {
genOpenACCRoutineConstruct(converter, module, func,
info->openACCRoutineInfos());
}
}
}
}
}
}
Expand Down
Loading