diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h
index f94981011b6e5..44e0e73028bf7 100644
--- a/flang/include/flang/Evaluate/tools.h
+++ b/flang/include/flang/Evaluate/tools.h
@@ -1303,6 +1303,18 @@ inline bool IsCUDADeviceSymbol(const Symbol &sym) {
   return false;
 }
 
+// True when the ultimate symbol carries the CUDA `managed` or `unified` data
+// attribute; such variables are host-accessible, so assignments touching only
+// them need no device data transfer.
+inline bool IsCUDAManagedOrUnifiedSymbol(const Symbol &sym) {
+  if (const auto *details =
+          sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) {
+    if (details->cudaDataAttr() &&
+        (*details->cudaDataAttr() == common::CUDADataAttr::Managed ||
+            *details->cudaDataAttr() == common::CUDADataAttr::Unified)) {
+      return true;
+    }
+  }
+  return false;
+}
+
 // Get the number of distinct symbols with CUDA device
 // attribute in the expression.
 template <typename A> inline int GetNbOfCUDADeviceSymbols(const A &expr) {
@@ -1315,12 +1327,45 @@ template <typename A> inline int GetNbOfCUDADeviceSymbols(const A &expr) {
   return symbols.size();
 }
 
+// Get the number of distinct symbols with CUDA managed or unified
+// attribute in the expression.
+template <typename A>
+inline int GetNbOfCUDAManagedOrUnifiedSymbols(const A &expr) {
+  semantics::UnorderedSymbolSet symbols;
+  for (const Symbol &sym : CollectCudaSymbols(expr)) {
+    if (IsCUDAManagedOrUnifiedSymbol(sym)) {
+      symbols.insert(sym);
+    }
+  }
+  return symbols.size();
+}
+
 // Check if any of the symbols part of the expression has a CUDA device
 // attribute.
 template <typename A> inline bool HasCUDADeviceAttrs(const A &expr) {
   return GetNbOfCUDADeviceSymbols(expr) > 0;
 }
 
+// Check if any of the symbols part of the lhs or rhs expression has a CUDA
+// device attribute, i.e. whether the assignment lhs = rhs is a CUDA data
+// transfer. Returns false for the special case where only a single managed or
+// unified symbol appears on each side, since that assignment is performed
+// directly on the host.
+template <typename A, typename B>
+inline bool IsCUDADataTransfer(const A &lhs, const B &rhs) {
+  int lhsNbManagedSymbols{GetNbOfCUDAManagedOrUnifiedSymbols(lhs)};
+  int rhsNbManagedSymbols{GetNbOfCUDAManagedOrUnifiedSymbols(rhs)};
+  int rhsNbSymbols{GetNbOfCUDADeviceSymbols(rhs)};
+
+  // Special case where only managed or unified symbols are involved. This is
+  // performed on the host.
+  if (lhsNbManagedSymbols == 1 && rhsNbManagedSymbols == 1 &&
+      rhsNbSymbols == 1) {
+    return false;
+  }
+  return HasCUDADeviceAttrs(lhs) || rhsNbSymbols > 0;
+}
+
 /// Check if the expression is a mix of host and device variables that require
 /// implicit data transfer.
 inline bool HasCUDAImplicitTransfer(const Expr<SomeType> &expr) {
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index f824e4c621c8e..cc19f335cd017 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -4640,10 +4640,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
 
     bool isInDeviceContext = Fortran::lower::isCudaDeviceContext(builder);
 
-    bool isCUDATransfer = (Fortran::evaluate::HasCUDADeviceAttrs(assign.lhs) ||
-                           Fortran::evaluate::HasCUDADeviceAttrs(assign.rhs)) &&
-                          !isInDeviceContext;
+    bool isCUDATransfer =
+        IsCUDADataTransfer(assign.lhs, assign.rhs) && !isInDeviceContext;
     bool hasCUDAImplicitTransfer =
+        isCUDATransfer &&
         Fortran::evaluate::HasCUDAImplicitTransfer(assign.rhs);
     llvm::SmallVector<mlir::Value> implicitTemps;
 
diff --git a/flang/lib/Semantics/assignment.cpp b/flang/lib/Semantics/assignment.cpp
index 627983d19a822..8de20d3126a6c 100644
--- a/flang/lib/Semantics/assignment.cpp
+++ b/flang/lib/Semantics/assignment.cpp
@@ -98,6 +98,10 @@ void AssignmentContext::Analyze(const parser::AssignmentStmt &stmt) {
   if (!IsCUDADeviceContext(&progUnit) && deviceConstructDepth_ == 0) {
     if (Fortran::evaluate::HasCUDADeviceAttrs(lhs) &&
         Fortran::evaluate::HasCUDAImplicitTransfer(rhs)) {
+      if (GetNbOfCUDAManagedOrUnifiedSymbols(lhs) == 1 &&
+          GetNbOfCUDAManagedOrUnifiedSymbols(rhs) == 1 &&
+          GetNbOfCUDADeviceSymbols(rhs) == 1)
+        return; // This is a special case handled on the host.
       context_.Say(lhsLoc, "Unsupported CUDA data transfer"_err_en_US);
     }
   }
diff --git a/flang/test/Lower/CUDA/cuda-data-transfer.cuf b/flang/test/Lower/CUDA/cuda-data-transfer.cuf
index 1c03a76cae76a..fa324abb137ec 100644
--- a/flang/test/Lower/CUDA/cuda-data-transfer.cuf
+++ b/flang/test/Lower/CUDA/cuda-data-transfer.cuf
@@ -393,4 +393,13 @@ end subroutine
 ! CHECK: %[[ALLOC_TMP:.*]] = fir.allocmem !fir.array<10xi32> {bindc_name = ".tmp", uniq_name = ""}
 ! CHECK: %[[TMP:.*]]:2 = hlfir.declare %[[ALLOC_TMP]](%{{.*}}) {uniq_name = ".tmp"} : (!fir.heap<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<10xi32>>, !fir.heap<!fir.array<10xi32>>)
 ! CHECK: cuf.data_transfer %[[ADEV_DECL]]#1 to %[[TMP]]#0 {transfer_kind = #cuf.cuda_transfer<device_host>} : !fir.ref<!fir.array<10xi32>>, !fir.heap<!fir.array<10xi32>>
+! CHECK: hlfir.assign
+
+subroutine sub20()
+  integer, managed :: a(10)
+  a = a + 2 ! ok. No data transfer. Assignment on the host.
+end subroutine
+
+! CHECK-LABEL: func.func @_QPsub20()
+! CHECK-NOT: cuf.data_transfer
 ! CHECK: hlfir.assign