42 changes: 42 additions & 0 deletions flang/include/flang/Evaluate/tools.h
@@ -1303,6 +1303,18 @@ inline bool IsCUDADeviceSymbol(const Symbol &sym) {
return false;
}

inline bool IsCUDAManagedOrUnifiedSymbol(const Symbol &sym) {
if (const auto *details =
sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) {
if (details->cudaDataAttr() &&
(*details->cudaDataAttr() == common::CUDADataAttr::Managed ||
*details->cudaDataAttr() == common::CUDADataAttr::Unified)) {
return true;
}
}
return false;
}

// Get the number of distinct symbols with a CUDA device
// attribute in the expression.
template <typename A> inline int GetNbOfCUDADeviceSymbols(const A &expr) {
@@ -1315,12 +1327,42 @@ template <typename A> inline int GetNbOfCUDADeviceSymbols(const A &expr) {
return symbols.size();
}

// Get the number of distinct symbols with a CUDA managed or unified
// attribute in the expression.
template <typename A>
inline int GetNbOfCUDAManagedOrUnifiedSymbols(const A &expr) {
semantics::UnorderedSymbolSet symbols;
for (const Symbol &sym : CollectCudaSymbols(expr)) {
if (IsCUDAManagedOrUnifiedSymbol(sym)) {
symbols.insert(sym);
}
}
return symbols.size();
}

// Check if any symbol that is part of the expression has a CUDA device
// attribute.
template <typename A> inline bool HasCUDADeviceAttrs(const A &expr) {
return GetNbOfCUDADeviceSymbols(expr) > 0;
}

// Check if the assignment is a CUDA data transfer, i.e. if any symbol in the
// lhs or rhs expression has a CUDA device attribute. An assignment involving
// only managed or unified symbols is not a transfer; it is performed on the
// host.
template <typename A, typename B>
inline bool IsCUDADataTransfer(const A &lhs, const B &rhs) {
  int lhsNbManagedSymbols{GetNbOfCUDAManagedOrUnifiedSymbols(lhs)};
  int rhsNbManagedSymbols{GetNbOfCUDAManagedOrUnifiedSymbols(rhs)};
  int rhsNbSymbols{GetNbOfCUDADeviceSymbols(rhs)};

  // Special case where only managed or unified symbols are involved. Such an
  // assignment is performed on the host.
if (lhsNbManagedSymbols == 1 && rhsNbManagedSymbols == 1 &&
rhsNbSymbols == 1) {
return false;
}
return HasCUDADeviceAttrs(lhs) || rhsNbSymbols > 0;
}

/// Check if the expression is a mix of host and device variables that require
/// implicit data transfer.
inline bool HasCUDAImplicitTransfer(const Expr<SomeType> &expr) {
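To make the new classification concrete, here is a hypothetical CUDA Fortran sketch (not part of the patch; the subroutine and variable names are illustrative) of how `IsCUDADataTransfer` treats a few assignments. It assumes that managed and unified variables are also counted as device symbols by `GetNbOfCUDADeviceSymbols`:

```fortran
! Illustrative only; not taken from the patch.
subroutine classify()
  integer, device :: d(10)
  integer, managed :: m(10)
  integer :: h(10)

  h = d     ! rhs holds a device symbol: classified as a data transfer
  d = h     ! lhs holds a device symbol: classified as a data transfer
  m = m + 2 ! exactly one managed symbol on each side and one device
            ! symbol on the rhs: not a transfer, assigned on the host
  m = d + 1 ! rhs managed/unified count is zero: still a data transfer
end subroutine
```

Note that the early return only fires when each side references exactly one managed or unified symbol; any additional device symbol on the rhs falls through to the regular transfer classification.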
6 changes: 3 additions & 3 deletions flang/lib/Lower/Bridge.cpp
@@ -4640,10 +4640,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {

bool isInDeviceContext = Fortran::lower::isCudaDeviceContext(builder);

-    bool isCUDATransfer = (Fortran::evaluate::HasCUDADeviceAttrs(assign.lhs) ||
-                           Fortran::evaluate::HasCUDADeviceAttrs(assign.rhs)) &&
-                          !isInDeviceContext;
+    bool isCUDATransfer =
+        IsCUDADataTransfer(assign.lhs, assign.rhs) && !isInDeviceContext;
     bool hasCUDAImplicitTransfer =
+        isCUDATransfer &&
         Fortran::evaluate::HasCUDAImplicitTransfer(assign.rhs);
llvm::SmallVector<mlir::Value> implicitTemps;

4 changes: 4 additions & 0 deletions flang/lib/Semantics/assignment.cpp
@@ -98,6 +98,10 @@ void AssignmentContext::Analyze(const parser::AssignmentStmt &stmt) {
if (!IsCUDADeviceContext(&progUnit) && deviceConstructDepth_ == 0) {
if (Fortran::evaluate::HasCUDADeviceAttrs(lhs) &&
Fortran::evaluate::HasCUDAImplicitTransfer(rhs)) {
if (GetNbOfCUDAManagedOrUnifiedSymbols(lhs) == 1 &&
GetNbOfCUDAManagedOrUnifiedSymbols(rhs) == 1 &&
GetNbOfCUDADeviceSymbols(rhs) == 1)
return; // This is a special case handled on the host.
context_.Say(lhsLoc, "Unsupported CUDA data transfer"_err_en_US);
}
}
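A hypothetical sketch of what the semantic check now accepts versus rejects (again with illustrative names; it assumes `HasCUDAImplicitTransfer` is true for a mix of host and device variables on the rhs, per its comment in tools.h above, and that managed symbols count as having device attributes):

```fortran
! Illustrative only; not taken from the patch's tests.
subroutine check()
  integer, device :: d(10)
  integer, managed :: m(10)
  integer :: h(10)

  d = d + h ! device lhs with a host/device mix on the rhs: still
            ! rejected with "Unsupported CUDA data transfer"
  m = m + h ! one managed symbol on each side and one device symbol on
            ! the rhs: the new early return accepts it; handled on host
end subroutine
```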
9 changes: 9 additions & 0 deletions flang/test/Lower/CUDA/cuda-data-transfer.cuf
@@ -393,4 +393,13 @@ end subroutine
! CHECK: %[[ALLOC_TMP:.*]] = fir.allocmem !fir.array<10xi32> {bindc_name = ".tmp", uniq_name = ""}
! CHECK: %[[TMP:.*]]:2 = hlfir.declare %[[ALLOC_TMP]](%{{.*}}) {uniq_name = ".tmp"} : (!fir.heap<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<10xi32>>, !fir.heap<!fir.array<10xi32>>)
! CHECK: cuf.data_transfer %[[ADEV_DECL]]#1 to %[[TMP]]#0 {transfer_kind = #cuf.cuda_transfer<device_host>} : !fir.ref<!fir.array<10xi32>>, !fir.heap<!fir.array<10xi32>>
! CHECK: hlfir.assign

subroutine sub20()
integer, managed :: a(10)
a = a + 2 ! ok. No data transfer. Assignment on the host.
end subroutine

! CHECK-LABEL: func.func @_QPsub20()
! CHECK-NOT: cuf.data_transfer
! CHECK: hlfir.assign