Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion flang/include/flang/Lower/OpenACC.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ static constexpr llvm::StringRef privatizationRecipePrefix = "privatization";
mlir::Value genOpenACCConstruct(AbstractConverter &,
Fortran::semantics::SemanticsContext &,
pft::Evaluation &,
const parser::OpenACCConstruct &);
const parser::OpenACCConstruct &,
Fortran::lower::SymMap &localSymbols);
void genOpenACCDeclarativeConstruct(
AbstractConverter &, Fortran::semantics::SemanticsContext &,
StatementContext &, const parser::OpenACCDeclarativeConstruct &);
Expand Down
4 changes: 4 additions & 0 deletions flang/include/flang/Lower/SymbolMap.h
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,10 @@ class SymMap {
return lookupSymbol(*sym);
}

/// Find a symbol by name and return its value if it appears in the current
/// mappings. This lookup is more expensive as it iterates over the map.
const semantics::Symbol *lookupSymbolByName(llvm::StringRef symName);

/// Find `symbol` and return its value if it appears in the inner-most level
/// map.
SymbolBox shallowLookupSymbol(semantics::SymbolRef sym);
Expand Down
2 changes: 1 addition & 1 deletion flang/include/flang/Semantics/symbol.h
Original file line number Diff line number Diff line change
Expand Up @@ -801,7 +801,7 @@ class Symbol {
AccPrivate, AccFirstPrivate, AccShared,
// OpenACC data-mapping attribute
AccCopy, AccCopyIn, AccCopyInReadOnly, AccCopyOut, AccCreate, AccDelete,
AccPresent, AccLink, AccDeviceResident, AccDevicePtr,
AccPresent, AccLink, AccDeviceResident, AccDevicePtr, AccUseDevice,
// OpenACC declare
AccDeclare,
// OpenACC data-movement attribute
Expand Down
2 changes: 1 addition & 1 deletion flang/lib/Lower/Bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3182,7 +3182,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
localSymbols.pushScope();
mlir::Value exitCond = genOpenACCConstruct(
*this, bridge.getSemanticsContext(), getEval(), acc);
*this, bridge.getSemanticsContext(), getEval(), acc, localSymbols);

const Fortran::parser::OpenACCLoopConstruct *accLoop =
std::get_if<Fortran::parser::OpenACCLoopConstruct>(&acc.u);
Expand Down
34 changes: 25 additions & 9 deletions flang/lib/Lower/OpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3184,7 +3184,8 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
Fortran::lower::pft::Evaluation &eval,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::StatementContext &stmtCtx,
const Fortran::parser::AccClauseList &accClauseList) {
const Fortran::parser::AccClauseList &accClauseList,
Fortran::lower::SymMap &localSymbols) {
mlir::Value ifCond;
llvm::SmallVector<mlir::Value> dataOperands;
bool addIfPresentAttr = false;
Expand All @@ -3199,6 +3200,19 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
} else if (const auto *useDevice =
std::get_if<Fortran::parser::AccClause::UseDevice>(
&clause.u)) {
// When CUDA Fotran is enabled, extra symbolds are used in the host_data
// region. Look for them and bind their value with the symbol in the outer
// scope.
if (semanticsContext.IsEnabled(Fortran::common::LanguageFeature::CUDA)) {
const Fortran::parser::AccObjectList &objectList{useDevice->v};
for (const auto &accObject : objectList.v) {
Fortran::semantics::Symbol &symbol =
getSymbolFromAccObject(accObject);
const Fortran::semantics::Symbol *baseSym =
localSymbols.lookupSymbolByName(symbol.name().ToString());
localSymbols.copySymbolBinding(*baseSym, symbol);
}
}
genDataOperandOperations<mlir::acc::UseDeviceOp>(
useDevice->v, converter, semanticsContext, stmtCtx, dataOperands,
mlir::acc::DataClause::acc_use_device,
Expand Down Expand Up @@ -3239,11 +3253,11 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
hostDataOp.setIfPresentAttr(builder.getUnitAttr());
}

static void
genACC(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::pft::Evaluation &eval,
const Fortran::parser::OpenACCBlockConstruct &blockConstruct) {
static void genACC(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::pft::Evaluation &eval,
const Fortran::parser::OpenACCBlockConstruct &blockConstruct,
Fortran::lower::SymMap &localSymbols) {
const auto &beginBlockDirective =
std::get<Fortran::parser::AccBeginBlockDirective>(blockConstruct.t);
const auto &blockDirective =
Expand Down Expand Up @@ -3273,7 +3287,7 @@ genACC(Fortran::lower::AbstractConverter &converter,
accClauseList);
} else if (blockDirective.v == llvm::acc::ACCD_host_data) {
genACCHostDataOp(converter, currentLocation, eval, semanticsContext,
stmtCtx, accClauseList);
stmtCtx, accClauseList, localSymbols);
}
}

Expand Down Expand Up @@ -4647,13 +4661,15 @@ mlir::Value Fortran::lower::genOpenACCConstruct(
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::pft::Evaluation &eval,
const Fortran::parser::OpenACCConstruct &accConstruct) {
const Fortran::parser::OpenACCConstruct &accConstruct,
Fortran::lower::SymMap &localSymbols) {

mlir::Value exitCond;
Fortran::common::visit(
common::visitors{
[&](const Fortran::parser::OpenACCBlockConstruct &blockConstruct) {
genACC(converter, semanticsContext, eval, blockConstruct);
genACC(converter, semanticsContext, eval, blockConstruct,
localSymbols);
},
[&](const Fortran::parser::OpenACCCombinedConstruct
&combinedConstruct) {
Expand Down
10 changes: 10 additions & 0 deletions flang/lib/Lower/SymbolMap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,16 @@ Fortran::lower::SymMap::lookupSymbol(Fortran::semantics::SymbolRef symRef) {
return SymbolBox::None{};
}

const Fortran::semantics::Symbol *
Fortran::lower::SymMap::lookupSymbolByName(llvm::StringRef symName) {
for (auto jmap = symbolMapStack.rbegin(), jend = symbolMapStack.rend();
jmap != jend; ++jmap)
for (auto const &[sym, symBox] : *jmap)
if (sym->name().ToString() == symName)
return sym;
return nullptr;
}

Fortran::lower::SymbolBox Fortran::lower::SymMap::shallowLookupSymbol(
Fortran::semantics::SymbolRef symRef) {
auto *sym = symRef->HasLocalLocality() ? &*symRef : &symRef->GetUltimate();
Expand Down
3 changes: 2 additions & 1 deletion flang/lib/Semantics/check-declarations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1189,7 +1189,8 @@ void CheckHelper::CheckObjectEntity(
}
} else if (!subpDetails && symbol.owner().kind() != Scope::Kind::Module &&
symbol.owner().kind() != Scope::Kind::MainProgram &&
symbol.owner().kind() != Scope::Kind::BlockConstruct) {
symbol.owner().kind() != Scope::Kind::BlockConstruct &&
symbol.owner().kind() != Scope::Kind::OpenACCConstruct) {
messages_.Say(
"ATTRIBUTES(%s) may apply only to module, host subprogram, block, or device subprogram data"_err_en_US,
parser::ToUpperCaseLetters(common::EnumToString(attr)));
Expand Down
5 changes: 5 additions & 0 deletions flang/lib/Semantics/resolve-directives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,11 @@ class AccAttributeVisitor : DirectiveAttributeVisitor<llvm::acc::Directive> {
return false;
}

bool Pre(const parser::AccClause::UseDevice &x) {
ResolveAccObjectList(x.v, Symbol::Flag::AccUseDevice);
return false;
}

void Post(const parser::Name &);

private:
Expand Down
65 changes: 64 additions & 1 deletion flang/lib/Semantics/resolve-names.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1387,6 +1387,8 @@ class ConstructVisitor : public virtual DeclarationVisitor {
// Create scopes for OpenACC constructs
class AccVisitor : public virtual DeclarationVisitor {
public:
explicit AccVisitor(SemanticsContext &context) : context_{context} {}

void AddAccSourceRange(const parser::CharBlock &);

static bool NeedsScope(const parser::OpenACCBlockConstruct &);
Expand All @@ -1395,6 +1397,7 @@ class AccVisitor : public virtual DeclarationVisitor {
void Post(const parser::OpenACCBlockConstruct &);
bool Pre(const parser::OpenACCCombinedConstruct &);
void Post(const parser::OpenACCCombinedConstruct &);
bool Pre(const parser::AccClause::UseDevice &x);
bool Pre(const parser::AccBeginBlockDirective &x) {
AddAccSourceRange(x.source);
return true;
Expand Down Expand Up @@ -1430,6 +1433,11 @@ class AccVisitor : public virtual DeclarationVisitor {
void Post(const parser::AccBeginLoopDirective &x) {
messageHandler().set_currStmtSource(std::nullopt);
}

void CopySymbolWithDevice(const parser::Name *name);

private:
SemanticsContext &context_;
};

bool AccVisitor::NeedsScope(const parser::OpenACCBlockConstruct &x) {
Expand Down Expand Up @@ -1459,6 +1467,60 @@ bool AccVisitor::Pre(const parser::OpenACCBlockConstruct &x) {
return true;
}

void AccVisitor::CopySymbolWithDevice(const parser::Name *name) {
// When CUDA Fortran is enabled together with OpenACC, new
// symbols are created for the one appearing in the use_device
// clause. These new symbols have the CUDA Fortran device
// attribute.
if (context_.languageFeatures().IsEnabled(common::LanguageFeature::CUDA)) {
name->symbol = currScope().CopySymbol(*name->symbol);
if (auto *object{name->symbol->detailsIf<ObjectEntityDetails>()}) {
object->set_cudaDataAttr(common::CUDADataAttr::Device);
}
}
}

bool AccVisitor::Pre(const parser::AccClause::UseDevice &x) {
for (const auto &accObject : x.v.v) {
common::visit(
common::visitors{
[&](const parser::Designator &designator) {
if (const auto *name{
semantics::getDesignatorNameIfDataRef(designator)}) {
Symbol *prev{currScope().FindSymbol(name->source)};
if (prev != name->symbol) {
name->symbol = prev;
}
CopySymbolWithDevice(name);
} else {
if (const auto *dataRef{
std::get_if<parser::DataRef>(&designator.u)}) {
using ElementIndirection =
common::Indirection<parser::ArrayElement>;
if (auto *ind{std::get_if<ElementIndirection>(&dataRef->u)}) {
const parser::ArrayElement &arrayElement{ind->value()};
Walk(arrayElement.subscripts);
const parser::DataRef &base{arrayElement.base};
if (auto *name{std::get_if<parser::Name>(&base.u)}) {
Symbol *prev{currScope().FindSymbol(name->source)};
if (prev != name->symbol) {
name->symbol = prev;
}
CopySymbolWithDevice(name);
}
}
}
}
},
[&](const parser::Name &name) {
// TODO: common block in use_device?
},
},
accObject.u);
}
return false;
}

void AccVisitor::Post(const parser::OpenACCBlockConstruct &x) {
if (NeedsScope(x)) {
PopScope();
Expand Down Expand Up @@ -2038,7 +2100,8 @@ class ResolveNamesVisitor : public virtual ScopeHandler,

ResolveNamesVisitor(
SemanticsContext &context, ImplicitRulesMap &rules, Scope &top)
: BaseVisitor{context, *this, rules}, topScope_{top} {
: BaseVisitor{context, *this, rules}, AccVisitor(context),
topScope_{top} {
PushScope(top);
}

Expand Down
43 changes: 43 additions & 0 deletions flang/test/Lower/OpenACC/acc-host-data-cuda-device.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@

! RUN: bbc -fopenacc -fcuda -emit-hlfir %s -o - | FileCheck %s

module m

interface doit
subroutine __device_sub(a)
real(4), device, intent(in) :: a(:,:,:)
!dir$ ignore_tkr(c) a
end
subroutine __host_sub(a)
real(4), intent(in) :: a(:,:,:)
!dir$ ignore_tkr(c) a
end
end interface
end module

program testex1
integer, parameter :: ntimes = 10
integer, parameter :: ni=128
integer, parameter :: nj=256
integer, parameter :: nk=64
real(4), dimension(ni,nj,nk) :: a

!$acc enter data copyin(a)

block; use m
!$acc host_data use_device(a)
do nt = 1, ntimes
call doit(a)
end do
!$acc end host_data
end block

block; use m
do nt = 1, ntimes
call doit(a)
end do
end block
end

! CHECK: fir.call @_QP__device_sub
! CHECK: fir.call @_QP__host_sub
Loading