Skip to content

Commit 73216cd

Browse files
authored
[flang] Rework CUDA kernel DO host array check (#116301)
Don't worry about derived type components unless they are pointers or allocatables.
1 parent e72209d commit 73216cd

File tree

1 file changed

+36
-16
lines changed

1 file changed

+36
-16
lines changed

flang/lib/Semantics/check-cuda.cpp

Lines changed: 36 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,37 @@ struct DeviceExprChecker
9191
}
9292
};
9393

94+
struct FindHostArray
95+
: public evaluate::AnyTraverse<FindHostArray, const Symbol *> {
96+
using Result = const Symbol *;
97+
using Base = evaluate::AnyTraverse<FindHostArray, Result>;
98+
FindHostArray() : Base(*this) {}
99+
using Base::operator();
100+
Result operator()(const evaluate::Component &x) const {
101+
const Symbol &symbol{x.GetLastSymbol()};
102+
if (IsAllocatableOrPointer(symbol)) {
103+
if (Result hostArray{(*this)(symbol)}) {
104+
return hostArray;
105+
}
106+
}
107+
return (*this)(x.base());
108+
}
109+
Result operator()(const Symbol &symbol) const {
110+
if (const auto *details{
111+
symbol.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()}) {
112+
if (details->IsArray() &&
113+
(!details->cudaDataAttr() ||
114+
(details->cudaDataAttr() &&
115+
*details->cudaDataAttr() != common::CUDADataAttr::Device &&
116+
*details->cudaDataAttr() != common::CUDADataAttr::Managed &&
117+
*details->cudaDataAttr() != common::CUDADataAttr::Unified))) {
118+
return &symbol;
119+
}
120+
}
121+
return nullptr;
122+
}
123+
};
124+
94125
template <typename A> static MaybeMsg CheckUnwrappedExpr(const A &x) {
95126
if (const auto *expr{parser::Unwrap<parser::Expr>(x)}) {
96127
return DeviceExprChecker{}(expr->typedExpr);
@@ -306,22 +337,11 @@ template <bool IsCUFKernelDo> class DeviceContextChecker {
306337
}
307338
}
308339
template <typename A>
309-
void ErrorIfHostSymbol(const A &expr, const parser::CharBlock &source) {
310-
for (const Symbol &sym : CollectCudaSymbols(expr)) {
311-
if (const auto *details =
312-
sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) {
313-
if (details->IsArray() &&
314-
(!details->cudaDataAttr() ||
315-
(details->cudaDataAttr() &&
316-
*details->cudaDataAttr() != common::CUDADataAttr::Device &&
317-
*details->cudaDataAttr() != common::CUDADataAttr::Managed &&
318-
*details->cudaDataAttr() !=
319-
common::CUDADataAttr::Unified))) {
320-
context_.Say(source,
321-
"Host array '%s' cannot be present in CUF kernel"_err_en_US,
322-
sym.name());
323-
}
324-
}
340+
void ErrorIfHostSymbol(const A &expr, parser::CharBlock source) {
341+
if (const Symbol * hostArray{FindHostArray{}(expr)}) {
342+
context_.Say(source,
343+
"Host array '%s' cannot be present in CUF kernel"_err_en_US,
344+
hostArray->name());
325345
}
326346
}
327347
void Check(const parser::ActionStmt &stmt, const parser::CharBlock &source) {

0 commit comments

Comments
 (0)