diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h index 6261a4eec4a55..dafacdf1ba0c5 100644 --- a/flang/include/flang/Evaluate/tools.h +++ b/flang/include/flang/Evaluate/tools.h @@ -1102,6 +1102,9 @@ extern template semantics::UnorderedSymbolSet CollectCudaSymbols( // Predicate: does a variable contain a vector-valued subscript (not a triplet)? bool HasVectorSubscript(const Expr &); +// Predicate: does an expression contain a constant outside of subscripts? +bool HasConstant(const Expr &); + // Utilities for attaching the location of the declaration of a symbol // of interest to a message. Handles the case of USE association gracefully. parser::Message *AttachDeclaration(parser::Message &, const Symbol &); @@ -1319,7 +1322,8 @@ inline bool HasCUDAImplicitTransfer(const Expr &expr) { ++hostSymbols; } } - return hostSymbols > 0 && deviceSymbols > 0; + bool hasConstant{HasConstant(expr)}; + return (hasConstant || (hostSymbols > 0)) && deviceSymbols > 0; } } // namespace Fortran::evaluate diff --git a/flang/lib/Evaluate/tools.cpp b/flang/lib/Evaluate/tools.cpp index 15e3e9452894d..a040f7ce79dc1 100644 --- a/flang/lib/Evaluate/tools.cpp +++ b/flang/lib/Evaluate/tools.cpp @@ -1051,6 +1051,23 @@ bool HasVectorSubscript(const Expr &expr) { return HasVectorSubscriptHelper{}(expr); } +// HasConstant() +struct HasConstantHelper : public AnyTraverse { + using Base = AnyTraverse; + HasConstantHelper() : Base{*this} {} + using Base::operator(); + template bool operator()(const Constant &) const { + return true; + } + // Only look for constants that are not inside a subscript. 
+ bool operator()(const Subscript &) const { return false; } +}; + +bool HasConstant(const Expr &expr) { + return HasConstantHelper{}(expr); +} + parser::Message *AttachDeclaration( parser::Message &message, const Symbol &symbol) { const Symbol *unhosted{&symbol}; diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp index cbae6955e2a07..17b58604da3bf 100644 --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -4416,6 +4416,7 @@ class FirConverter : public Fortran::lower::AbstractConverter { bool hasCUDAImplicitTransfer = Fortran::evaluate::HasCUDAImplicitTransfer(assign.rhs); llvm::SmallVector implicitTemps; + if (hasCUDAImplicitTransfer && !isInDeviceContext) implicitTemps = genCUDAImplicitDataTransfer(builder, loc, assign); diff --git a/flang/test/Lower/CUDA/cuda-data-transfer.cuf b/flang/test/Lower/CUDA/cuda-data-transfer.cuf index 3b6cd67d9a8fa..cbddcd79c6333 100644 --- a/flang/test/Lower/CUDA/cuda-data-transfer.cuf +++ b/flang/test/Lower/CUDA/cuda-data-transfer.cuf @@ -38,7 +38,6 @@ subroutine sub1() adev = 10 cdev = 0 - end ! CHECK-LABEL: func.func @_QPsub1() @@ -381,3 +380,17 @@ end subroutine ! CHECK-LABEL: func.func @_QPsub18 ! CHECK-NOT: cuf.data_transfer + +subroutine sub19() + integer, device :: adev(10) + integer :: ahost(10) + ! Implicit data transfer of adev and then addition on the host + ahost = adev + 2 +end subroutine + +! CHECK-LABEL: func.func @_QPsub19() +! CHECK: %[[ADEV_DECL:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}}) {data_attr = #cuf.cuda, uniq_name = "_QFsub19Eadev"} : (!fir.ref>, !fir.shape<1>) -> (!fir.ref>, !fir.ref>) +! CHECK: %[[ALLOC_TMP:.*]] = fir.allocmem !fir.array<10xi32> {bindc_name = ".tmp", uniq_name = ""} +! CHECK: %[[TMP:.*]]:2 = hlfir.declare %[[ALLOC_TMP]](%{{.*}}) {uniq_name = ".tmp"} : (!fir.heap>, !fir.shape<1>) -> (!fir.heap>, !fir.heap>) +! CHECK: cuf.data_transfer %[[ADEV_DECL]]#1 to %[[TMP]]#0 {transfer_kind = #cuf.cuda_transfer} : !fir.ref>, !fir.heap> +! CHECK: hlfir.assign