Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 0 additions & 11 deletions flang/include/flang/Evaluate/tools.h
Original file line number Diff line number Diff line change
Expand Up @@ -1286,17 +1286,6 @@ bool CheckForCoindexedObject(parser::ContextualMessages &,
const std::optional<ActualArgument> &, const std::string &procName,
const std::string &argName);

inline bool CanCUDASymbolHaveSaveAttr(const Symbol &sym) {
if (const auto *details =
sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) {
if (details->cudaDataAttr() &&
*details->cudaDataAttr() != common::CUDADataAttr::Unified) {
return false;
}
}
return true;
}

inline bool IsCUDADeviceSymbol(const Symbol &sym) {
if (const auto *details =
sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) {
Expand Down
2 changes: 2 additions & 0 deletions flang/include/flang/Semantics/tools.h
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,8 @@ inline bool NeedCUDAAlloc(const Symbol &sym) {
return false;
}

bool CanCUDASymbolBeGlobal(const Symbol &sym);

const Scope *FindCUDADeviceContext(const Scope *);
std::optional<common::CUDADataAttr> GetCUDADataAttr(const Symbol *);

Expand Down
12 changes: 5 additions & 7 deletions flang/lib/Evaluate/tools.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2194,8 +2194,7 @@ bool IsSaved(const Symbol &original) {
return false;
} else if (scopeKind == Scope::Kind::Module ||
(scopeKind == Scope::Kind::MainProgram &&
(symbol.attrs().test(Attr::TARGET) || evaluate::IsCoarray(symbol)) &&
Fortran::evaluate::CanCUDASymbolHaveSaveAttr(symbol))) {
(symbol.attrs().test(Attr::TARGET) || evaluate::IsCoarray(symbol)))) {
// 8.5.16p4
// In main programs, implied SAVE matters only for pointer
// initialization targets and coarrays.
Expand All @@ -2204,8 +2203,7 @@ bool IsSaved(const Symbol &original) {
(features.IsEnabled(common::LanguageFeature::SaveMainProgram) ||
(features.IsEnabled(
common::LanguageFeature::SaveBigMainProgramVariables) &&
symbol.size() > 32)) &&
Fortran::evaluate::CanCUDASymbolHaveSaveAttr(symbol)) {
symbol.size() > 32))) {
// With SaveBigMainProgramVariables, keeping all unsaved main program
// variables of 32 bytes or less on the stack allows keeping numerical and
// logical scalars, small scalar characters or derived, small arrays, and
Expand All @@ -2223,15 +2221,15 @@ bool IsSaved(const Symbol &original) {
} else if (symbol.test(Symbol::Flag::InDataStmt)) {
return true;
} else if (const auto *object{symbol.detailsIf<ObjectEntityDetails>()};
object && object->init()) {
object && object->init()) {
return true;
} else if (IsProcedurePointer(symbol) && symbol.has<ProcEntityDetails>() &&
symbol.get<ProcEntityDetails>().init()) {
return true;
} else if (scope.hasSAVE()) {
return true; // bare SAVE statement
} else if (const Symbol * block{FindCommonBlockContaining(symbol)};
block && block->attrs().test(Attr::SAVE)) {
} else if (const Symbol *block{FindCommonBlockContaining(symbol)};
block && block->attrs().test(Attr::SAVE)) {
return true; // in COMMON with SAVE
} else {
return false;
Expand Down
4 changes: 2 additions & 2 deletions flang/lib/Lower/PFTBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1472,8 +1472,8 @@ bool Fortran::lower::definedInCommonBlock(const semantics::Symbol &sym) {

/// Is the symbol `sym` a global?
bool Fortran::lower::symbolIsGlobal(const semantics::Symbol &sym) {
return semantics::IsSaved(sym) || lower::definedInCommonBlock(sym) ||
semantics::IsNamedConstant(sym);
return (semantics::IsSaved(sym) && semantics::CanCUDASymbolBeGlobal(sym)) ||
lower::definedInCommonBlock(sym) || semantics::IsNamedConstant(sym);
}

namespace {
Expand Down
48 changes: 47 additions & 1 deletion flang/lib/Semantics/tools.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1087,6 +1087,52 @@ std::optional<common::CUDADataAttr> GetCUDADataAttr(const Symbol *symbol) {
return object ? object->cudaDataAttr() : std::nullopt;
}

bool IsDeviceAllocatable(const Symbol &symbol) {
if (IsAllocatable(symbol)) {
if (const auto *details{
symbol.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()}) {
if (details->cudaDataAttr() &&
*details->cudaDataAttr() != common::CUDADataAttr::Pinned) {
return true;
}
}
}
return false;
}

UltimateComponentIterator::const_iterator
FindCUDADeviceAllocatableUltimateComponent(const DerivedTypeSpec &derived) {
UltimateComponentIterator ultimates{derived};
return std::find_if(ultimates.begin(), ultimates.end(), IsDeviceAllocatable);
}

bool CanCUDASymbolBeGlobal(const Symbol &sym) {
const Symbol &symbol{GetAssociationRoot(sym)};
const Scope &scope{symbol.owner()};
auto scopeKind{scope.kind()};
const common::LanguageFeatureControl &features{
scope.context().languageFeatures()};
if (features.IsEnabled(common::LanguageFeature::CUDA) &&
scopeKind == Scope::Kind::MainProgram) {
if (const auto *details{
sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()}) {
const Fortran::semantics::DeclTypeSpec *type{details->type()};
const Fortran::semantics::DerivedTypeSpec *derived{
type ? type->AsDerived() : nullptr};
if (derived) {
if (FindCUDADeviceAllocatableUltimateComponent(*derived)) {
return false;
}
}
if (details->cudaDataAttr() &&
*details->cudaDataAttr() != common::CUDADataAttr::Unified) {
return false;
}
}
}
return true;
}

bool IsAccessible(const Symbol &original, const Scope &scope) {
const Symbol &ultimate{original.GetUltimate()};
if (ultimate.attrs().test(Attr::PRIVATE)) {
Expand Down Expand Up @@ -1788,4 +1834,4 @@ bool HadUseError(
}
}

} // namespace Fortran::semantics
} // namespace Fortran::semantics
20 changes: 20 additions & 0 deletions flang/test/Lower/CUDA/cuda-derived.cuf
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s

module m1
type ty_device
integer, device, allocatable, dimension(:) :: x
end type

type t1; real, device, allocatable :: a(:); end type
type t2; type(t1) :: b; end type
end module

program main
use m1
type(ty_device) :: a
type(t2) :: b
end

! CHECK-LABEL: func.func @_QQmain() attributes {fir.bindc_name = "main"}
! CHECK: %{{.*}} = fir.alloca !fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>}> {bindc_name = "a", uniq_name = "_QFEa"}
! CHECK: %{{.*}} = fir.alloca !fir.type<_QMm1Tt2{b:!fir.type<_QMm1Tt1{a:!fir.box<!fir.heap<!fir.array<?xf32>>>}>}> {bindc_name = "b", uniq_name = "_QFEb"}
Loading