diff --git a/flang/lib/Semantics/check-cuda.cpp b/flang/lib/Semantics/check-cuda.cpp index eaf1d52a9fc1a..79b7a26ef222f 100644 --- a/flang/lib/Semantics/check-cuda.cpp +++ b/flang/lib/Semantics/check-cuda.cpp @@ -91,6 +91,37 @@ struct DeviceExprChecker } }; +struct FindHostArray + : public evaluate::AnyTraverse { + using Result = const Symbol *; + using Base = evaluate::AnyTraverse; + FindHostArray() : Base(*this) {} + using Base::operator(); + Result operator()(const evaluate::Component &x) const { + const Symbol &symbol{x.GetLastSymbol()}; + if (IsAllocatableOrPointer(symbol)) { + if (Result hostArray{(*this)(symbol)}) { + return hostArray; + } + } + return (*this)(x.base()); + } + Result operator()(const Symbol &symbol) const { + if (const auto *details{ + symbol.GetUltimate().detailsIf()}) { + if (details->IsArray() && + (!details->cudaDataAttr() || + (details->cudaDataAttr() && + *details->cudaDataAttr() != common::CUDADataAttr::Device && + *details->cudaDataAttr() != common::CUDADataAttr::Managed && + *details->cudaDataAttr() != common::CUDADataAttr::Unified))) { + return &symbol; + } + } + return nullptr; + } +}; + template static MaybeMsg CheckUnwrappedExpr(const A &x) { if (const auto *expr{parser::Unwrap(x)}) { return DeviceExprChecker{}(expr->typedExpr); @@ -306,22 +337,11 @@ template class DeviceContextChecker { } } template - void ErrorIfHostSymbol(const A &expr, const parser::CharBlock &source) { - for (const Symbol &sym : CollectCudaSymbols(expr)) { - if (const auto *details = - sym.GetUltimate().detailsIf()) { - if (details->IsArray() && - (!details->cudaDataAttr() || - (details->cudaDataAttr() && - *details->cudaDataAttr() != common::CUDADataAttr::Device && - *details->cudaDataAttr() != common::CUDADataAttr::Managed && - *details->cudaDataAttr() != - common::CUDADataAttr::Unified))) { - context_.Say(source, - "Host array '%s' cannot be present in CUF kernel"_err_en_US, - sym.name()); - } - } + void ErrorIfHostSymbol(const A &expr, parser::CharBlock source) { + if (const Symbol * hostArray{FindHostArray{}(expr)}) { + context_.Say(source, + "Host array '%s' cannot be present in CUF kernel"_err_en_US, + hostArray->name()); } } void Check(const parser::ActionStmt &stmt, const parser::CharBlock &source) {