@@ -702,6 +702,29 @@ static void instantiateGlobal(Fortran::lower::AbstractConverter &converter,
702702 mapSymbolAttributes (converter, var, symMap, stmtCtx, cast);
703703}
704704
705+ bool needCUDAAlloc (const Fortran::semantics::Symbol &sym) {
706+ if (Fortran::semantics::IsDummy (sym))
707+ return false ;
708+ if (const auto *details{
709+ sym.GetUltimate ()
710+ .detailsIf <Fortran::semantics::ObjectEntityDetails>()}) {
711+ if (details->cudaDataAttr () &&
712+ (*details->cudaDataAttr () == Fortran::common::CUDADataAttr::Device ||
713+ *details->cudaDataAttr () == Fortran::common::CUDADataAttr::Managed ||
714+ *details->cudaDataAttr () == Fortran::common::CUDADataAttr::Unified ||
715+ *details->cudaDataAttr () == Fortran::common::CUDADataAttr::Shared ||
716+ *details->cudaDataAttr () == Fortran::common::CUDADataAttr::Pinned))
717+ return true ;
718+ const Fortran::semantics::DeclTypeSpec *type{details->type ()};
719+ const Fortran::semantics::DerivedTypeSpec *derived{type ? type->AsDerived ()
720+ : nullptr };
721+ if (derived)
722+ if (FindCUDADeviceAllocatableUltimateComponent (*derived))
723+ return true ;
724+ }
725+ return false ;
726+ }
727+
705728// ===----------------------------------------------------------------===//
706729// Local variables instantiation (not for alias)
707730// ===----------------------------------------------------------------===//
@@ -732,7 +755,7 @@ static mlir::Value createNewLocal(Fortran::lower::AbstractConverter &converter,
732755 if (ultimateSymbol.test (Fortran::semantics::Symbol::Flag::CrayPointee))
733756 return builder.create <fir::ZeroOp>(loc, fir::ReferenceType::get (ty));
734757
735- if (Fortran::semantics::NeedCUDAAlloc (ultimateSymbol)) {
758+ if (needCUDAAlloc (ultimateSymbol)) {
736759 cuf::DataAttributeAttr dataAttr =
737760 Fortran::lower::translateSymbolCUFDataAttribute (builder.getContext (),
738761 ultimateSymbol);
@@ -1087,7 +1110,7 @@ static void instantiateLocal(Fortran::lower::AbstractConverter &converter,
10871110 Fortran::lower::defaultInitializeAtRuntime (converter, var.getSymbol (),
10881111 symMap);
10891112 auto *builder = &converter.getFirOpBuilder ();
1090- if (Fortran::semantics::NeedCUDAAlloc (var.getSymbol ()) &&
1113+ if (needCUDAAlloc (var.getSymbol ()) &&
10911114 !cuf::isCUDADeviceContext (builder->getRegion ())) {
10921115 cuf::DataAttributeAttr dataAttr =
10931116 Fortran::lower::translateSymbolCUFDataAttribute (builder->getContext (),
0 commit comments