diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h index cad1b634f8924..18c244f6f450f 100644 --- a/flang/include/flang/Evaluate/tools.h +++ b/flang/include/flang/Evaluate/tools.h @@ -1286,17 +1286,6 @@ bool CheckForCoindexedObject(parser::ContextualMessages &, const std::optional &, const std::string &procName, const std::string &argName); -inline bool CanCUDASymbolHaveSaveAttr(const Symbol &sym) { - if (const auto *details = - sym.GetUltimate().detailsIf()) { - if (details->cudaDataAttr() && - *details->cudaDataAttr() != common::CUDADataAttr::Unified) { - return false; - } - } - return true; -} - inline bool IsCUDADeviceSymbol(const Symbol &sym) { if (const auto *details = sym.GetUltimate().detailsIf()) { diff --git a/flang/include/flang/Semantics/tools.h b/flang/include/flang/Semantics/tools.h index f3cfa9b99fb4d..ea07128a6d240 100644 --- a/flang/include/flang/Semantics/tools.h +++ b/flang/include/flang/Semantics/tools.h @@ -248,6 +248,8 @@ inline bool NeedCUDAAlloc(const Symbol &sym) { return false; } +bool CanCUDASymbolBeGlobal(const Symbol &sym); + const Scope *FindCUDADeviceContext(const Scope *); std::optional GetCUDADataAttr(const Symbol *); diff --git a/flang/lib/Evaluate/tools.cpp b/flang/lib/Evaluate/tools.cpp index fcacdb93d662b..6a57d87a30e93 100644 --- a/flang/lib/Evaluate/tools.cpp +++ b/flang/lib/Evaluate/tools.cpp @@ -2194,8 +2194,7 @@ bool IsSaved(const Symbol &original) { return false; } else if (scopeKind == Scope::Kind::Module || (scopeKind == Scope::Kind::MainProgram && - (symbol.attrs().test(Attr::TARGET) || evaluate::IsCoarray(symbol)) && - Fortran::evaluate::CanCUDASymbolHaveSaveAttr(symbol))) { + (symbol.attrs().test(Attr::TARGET) || evaluate::IsCoarray(symbol)))) { // 8.5.16p4 // In main programs, implied SAVE matters only for pointer // initialization targets and coarrays. @@ -2204,8 +2203,7 @@ bool IsSaved(const Symbol &original) { (features.IsEnabled(common::LanguageFeature::SaveMainProgram) || (features.IsEnabled( common::LanguageFeature::SaveBigMainProgramVariables) && - symbol.size() > 32)) && - Fortran::evaluate::CanCUDASymbolHaveSaveAttr(symbol)) { + symbol.size() > 32))) { // With SaveBigMainProgramVariables, keeping all unsaved main program // variables of 32 bytes or less on the stack allows keeping numerical and // logical scalars, small scalar characters or derived, small arrays, and @@ -2223,15 +2221,15 @@ bool IsSaved(const Symbol &original) { } else if (symbol.test(Symbol::Flag::InDataStmt)) { return true; } else if (const auto *object{symbol.detailsIf()}; - object && object->init()) { + object && object->init()) { return true; } else if (IsProcedurePointer(symbol) && symbol.has() && symbol.get().init()) { return true; } else if (scope.hasSAVE()) { return true; // bare SAVE statement - } else if (const Symbol * block{FindCommonBlockContaining(symbol)}; - block && block->attrs().test(Attr::SAVE)) { + } else if (const Symbol *block{FindCommonBlockContaining(symbol)}; + block && block->attrs().test(Attr::SAVE)) { return true; // in COMMON with SAVE } else { return false; diff --git a/flang/lib/Lower/PFTBuilder.cpp b/flang/lib/Lower/PFTBuilder.cpp index 68023610c3c50..a28cc019a7310 100644 --- a/flang/lib/Lower/PFTBuilder.cpp +++ b/flang/lib/Lower/PFTBuilder.cpp @@ -1472,8 +1472,8 @@ bool Fortran::lower::definedInCommonBlock(const semantics::Symbol &sym) { /// Is the symbol `sym` a global? bool Fortran::lower::symbolIsGlobal(const semantics::Symbol &sym) { - return semantics::IsSaved(sym) || lower::definedInCommonBlock(sym) || - semantics::IsNamedConstant(sym); + return (semantics::IsSaved(sym) && semantics::CanCUDASymbolBeGlobal(sym)) || + lower::definedInCommonBlock(sym) || semantics::IsNamedConstant(sym); } namespace { diff --git a/flang/lib/Semantics/tools.cpp b/flang/lib/Semantics/tools.cpp index d053179448c00..aed57216f13b7 100644 --- a/flang/lib/Semantics/tools.cpp +++ b/flang/lib/Semantics/tools.cpp @@ -1087,6 +1087,52 @@ std::optional GetCUDADataAttr(const Symbol *symbol) { return object ? object->cudaDataAttr() : std::nullopt; } +bool IsDeviceAllocatable(const Symbol &symbol) { + if (IsAllocatable(symbol)) { + if (const auto *details{ + symbol.GetUltimate().detailsIf()}) { + if (details->cudaDataAttr() && + *details->cudaDataAttr() != common::CUDADataAttr::Pinned) { + return true; + } + } + } + return false; +} + +UltimateComponentIterator::const_iterator +FindCUDADeviceAllocatableUltimateComponent(const DerivedTypeSpec &derived) { + UltimateComponentIterator ultimates{derived}; + return std::find_if(ultimates.begin(), ultimates.end(), IsDeviceAllocatable); +} + +bool CanCUDASymbolBeGlobal(const Symbol &sym) { + const Symbol &symbol{GetAssociationRoot(sym)}; + const Scope &scope{symbol.owner()}; + auto scopeKind{scope.kind()}; + const common::LanguageFeatureControl &features{ + scope.context().languageFeatures()}; + if (features.IsEnabled(common::LanguageFeature::CUDA) && + scopeKind == Scope::Kind::MainProgram) { + if (const auto *details{ + sym.GetUltimate().detailsIf()}) { + const Fortran::semantics::DeclTypeSpec *type{details->type()}; + const Fortran::semantics::DerivedTypeSpec *derived{ + type ? type->AsDerived() : nullptr}; + if (derived) { + if (FindCUDADeviceAllocatableUltimateComponent(*derived)) { + return false; + } + } + if (details->cudaDataAttr() && + *details->cudaDataAttr() != common::CUDADataAttr::Unified) { + return false; + } + } + } + return true; +} + bool IsAccessible(const Symbol &original, const Scope &scope) { const Symbol &ultimate{original.GetUltimate()}; if (ultimate.attrs().test(Attr::PRIVATE)) { @@ -1788,4 +1834,4 @@ bool HadUseError( } } -} // namespace Fortran::semantics \ No newline at end of file +} // namespace Fortran::semantics diff --git a/flang/test/Lower/CUDA/cuda-derived.cuf b/flang/test/Lower/CUDA/cuda-derived.cuf new file mode 100644 index 0000000000000..d280ac722d08f --- /dev/null +++ b/flang/test/Lower/CUDA/cuda-derived.cuf @@ -0,0 +1,20 @@ +! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s + +module m1 + type ty_device + integer, device, allocatable, dimension(:) :: x + end type + + type t1; real, device, allocatable :: a(:); end type + type t2; type(t1) :: b; end type +end module + +program main + use m1 + type(ty_device) :: a + type(t2) :: b +end + +! CHECK-LABEL: func.func @_QQmain() attributes {fir.bindc_name = "main"} +! CHECK: %{{.*}} = fir.alloca !fir.type<_QMm1Tty_device{x:!fir.box>>}> {bindc_name = "a", uniq_name = "_QFEa"} +! CHECK: %{{.*}} = fir.alloca !fir.type<_QMm1Tt2{b:!fir.type<_QMm1Tt1{a:!fir.box>>}>}> {bindc_name = "b", uniq_name = "_QFEb"}