Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 1 addition & 10 deletions flang/include/flang/Evaluate/tools.h
Original file line number Diff line number Diff line change
Expand Up @@ -1286,16 +1286,7 @@ bool CheckForCoindexedObject(parser::ContextualMessages &,
const std::optional<ActualArgument> &, const std::string &procName,
const std::string &argName);

inline bool CanCUDASymbolHaveSaveAttr(const Symbol &sym) {
if (const auto *details =
sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) {
if (details->cudaDataAttr() &&
*details->cudaDataAttr() != common::CUDADataAttr::Unified) {
return false;
}
}
return true;
}
bool CanCUDASymbolHaveSaveAttr(const Symbol &sym);

inline bool IsCUDADeviceSymbol(const Symbol &sym) {
if (const auto *details =
Expand Down
2 changes: 2 additions & 0 deletions flang/include/flang/Semantics/tools.h
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,8 @@ DirectComponentIterator::const_iterator FindAllocatableOrPointerDirectComponent(
const DerivedTypeSpec &);
PotentialComponentIterator::const_iterator
FindPolymorphicAllocatablePotentialComponent(const DerivedTypeSpec &);
UltimateComponentIterator::const_iterator
FindCUDADeviceAllocatableUltimateComponent(const DerivedTypeSpec &);

// The LabelEnforce class (given a set of labels) provides an error message if
// there is a branch to a label which is not in the given set.
Expand Down
29 changes: 24 additions & 5 deletions flang/lib/Evaluate/tools.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2173,6 +2173,25 @@ bool IsAutomatic(const Symbol &original) {
return false;
}

bool CanCUDASymbolHaveSaveAttr(const Symbol &sym) {
if (const auto *details =
sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) {
const Fortran::semantics::DeclTypeSpec *type{details->type()};
const Fortran::semantics::DerivedTypeSpec *derived{
type ? type->AsDerived() : nullptr};
if (derived) {
if (auto iter{FindCUDADeviceAllocatableUltimateComponent(*derived)}) {
return false;
}
}
if (details->cudaDataAttr() &&
*details->cudaDataAttr() != common::CUDADataAttr::Unified) {
return false;
}
}
return true;
}

bool IsSaved(const Symbol &original) {
const Symbol &symbol{GetAssociationRoot(original)};
const Scope &scope{symbol.owner()};
Expand All @@ -2195,7 +2214,7 @@ bool IsSaved(const Symbol &original) {
} else if (scopeKind == Scope::Kind::Module ||
(scopeKind == Scope::Kind::MainProgram &&
(symbol.attrs().test(Attr::TARGET) || evaluate::IsCoarray(symbol)) &&
Fortran::evaluate::CanCUDASymbolHaveSaveAttr(symbol))) {
CanCUDASymbolHaveSaveAttr(symbol))) {
// 8.5.16p4
// In main programs, implied SAVE matters only for pointer
// initialization targets and coarrays.
Expand All @@ -2205,7 +2224,7 @@ bool IsSaved(const Symbol &original) {
(features.IsEnabled(
common::LanguageFeature::SaveBigMainProgramVariables) &&
symbol.size() > 32)) &&
Fortran::evaluate::CanCUDASymbolHaveSaveAttr(symbol)) {
CanCUDASymbolHaveSaveAttr(symbol)) {
// With SaveBigMainProgramVariables, keeping all unsaved main program
// variables of 32 bytes or less on the stack allows keeping numerical and
// logical scalars, small scalar characters or derived, small arrays, and
Expand All @@ -2223,15 +2242,15 @@ bool IsSaved(const Symbol &original) {
} else if (symbol.test(Symbol::Flag::InDataStmt)) {
return true;
} else if (const auto *object{symbol.detailsIf<ObjectEntityDetails>()};
object && object->init()) {
object && object->init()) {
return true;
} else if (IsProcedurePointer(symbol) && symbol.has<ProcEntityDetails>() &&
symbol.get<ProcEntityDetails>().init()) {
return true;
} else if (scope.hasSAVE()) {
return true; // bare SAVE statement
} else if (const Symbol * block{FindCommonBlockContaining(symbol)};
block && block->attrs().test(Attr::SAVE)) {
} else if (const Symbol *block{FindCommonBlockContaining(symbol)};
block && block->attrs().test(Attr::SAVE)) {
return true; // in COMMON with SAVE
} else {
return false;
Expand Down
21 changes: 20 additions & 1 deletion flang/lib/Semantics/tools.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1081,6 +1081,19 @@ const Scope *FindCUDADeviceContext(const Scope *scope) {
});
}

bool IsDeviceAllocatable(const Symbol &symbol) {
if (IsAllocatable(symbol)) {
if (const auto *details =
symbol.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) {
if (details->cudaDataAttr() &&
*details->cudaDataAttr() != common::CUDADataAttr::Pinned) {
return true;
}
}
}
return false;
}

std::optional<common::CUDADataAttr> GetCUDADataAttr(const Symbol *symbol) {
const auto *object{
symbol ? symbol->detailsIf<ObjectEntityDetails>() : nullptr};
Expand Down Expand Up @@ -1426,6 +1439,12 @@ FindPolymorphicAllocatablePotentialComponent(const DerivedTypeSpec &derived) {
potentials.begin(), potentials.end(), IsPolymorphicAllocatable);
}

UltimateComponentIterator::const_iterator
FindCUDADeviceAllocatableUltimateComponent(const DerivedTypeSpec &derived) {
UltimateComponentIterator ultimates{derived};
return std::find_if(ultimates.begin(), ultimates.end(), IsDeviceAllocatable);
}

const Symbol *FindUltimateComponent(const DerivedTypeSpec &derived,
const std::function<bool(const Symbol &)> &predicate) {
UltimateComponentIterator ultimates{derived};
Expand Down Expand Up @@ -1788,4 +1807,4 @@ bool HadUseError(
}
}

} // namespace Fortran::semantics
} // namespace Fortran::semantics
15 changes: 15 additions & 0 deletions flang/test/Lower/CUDA/cuda-derived.cuf
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s

module m1
type ty_device
integer, device, allocatable, dimension(:) :: x
end type
end module

program main
use m1
type(ty_device) :: a
end

! CHECK-LABEL: func.func @_QQmain() attributes {fir.bindc_name = "main"}
! CHECK: %{{.*}} = fir.alloca !fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>}> {bindc_name = "a", uniq_name = "_QFEa"}
Loading