Skip to content

Commit d1fd369

Browse files
authored
[flang][cuda] Allow unsupported data transfer to be done on the host (#129160)
Some data transfer marked as unsupported can actually be deferred to an assignment on the host when the variables involved are unified or managed.
1 parent b02cfbd commit d1fd369

File tree

4 files changed

+58
-3
lines changed

4 files changed

+58
-3
lines changed

flang/include/flang/Evaluate/tools.h

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1303,6 +1303,18 @@ inline bool IsCUDADeviceSymbol(const Symbol &sym) {
13031303
return false;
13041304
}
13051305

1306+
inline bool IsCUDAManagedOrUnifiedSymbol(const Symbol &sym) {
1307+
if (const auto *details =
1308+
sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) {
1309+
if (details->cudaDataAttr() &&
1310+
(*details->cudaDataAttr() == common::CUDADataAttr::Managed ||
1311+
*details->cudaDataAttr() == common::CUDADataAttr::Unified)) {
1312+
return true;
1313+
}
1314+
}
1315+
return false;
1316+
}
1317+
13061318
// Get the number of distinct symbols with CUDA device
13071319
// attribute in the expression.
13081320
template <typename A> inline int GetNbOfCUDADeviceSymbols(const A &expr) {
@@ -1315,12 +1327,42 @@ template <typename A> inline int GetNbOfCUDADeviceSymbols(const A &expr) {
13151327
return symbols.size();
13161328
}
13171329

1330+
// Get the number of distinct symbols with CUDA managed or unified
1331+
// attribute in the expression.
1332+
template <typename A>
1333+
inline int GetNbOfCUDAManagedOrUnifiedSymbols(const A &expr) {
1334+
semantics::UnorderedSymbolSet symbols;
1335+
for (const Symbol &sym : CollectCudaSymbols(expr)) {
1336+
if (IsCUDAManagedOrUnifiedSymbol(sym)) {
1337+
symbols.insert(sym);
1338+
}
1339+
}
1340+
return symbols.size();
1341+
}
1342+
13181343
// Check if any of the symbols part of the expression has a CUDA device
13191344
// attribute.
13201345
template <typename A> inline bool HasCUDADeviceAttrs(const A &expr) {
13211346
return GetNbOfCUDADeviceSymbols(expr) > 0;
13221347
}
13231348

1349+
// Check if any of the symbols part of the lhs or rhs expression has a CUDA
1350+
// device attribute.
1351+
template <typename A, typename B>
1352+
inline bool IsCUDADataTransfer(const A &lhs, const B &rhs) {
1353+
int lhsNbManagedSymbols = {GetNbOfCUDAManagedOrUnifiedSymbols(lhs)};
1354+
int rhsNbManagedSymbols = {GetNbOfCUDAManagedOrUnifiedSymbols(rhs)};
1355+
int rhsNbSymbols{GetNbOfCUDADeviceSymbols(rhs)};
1356+
1357+
// Special case where only managed or unifed symbols are involved. This is
1358+
// performed on the host.
1359+
if (lhsNbManagedSymbols == 1 && rhsNbManagedSymbols == 1 &&
1360+
rhsNbSymbols == 1) {
1361+
return false;
1362+
}
1363+
return HasCUDADeviceAttrs(lhs) || rhsNbSymbols > 0;
1364+
}
1365+
13241366
/// Check if the expression is a mix of host and device variables that require
13251367
/// implicit data transfer.
13261368
inline bool HasCUDAImplicitTransfer(const Expr<SomeType> &expr) {

flang/lib/Lower/Bridge.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4640,10 +4640,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
46404640

46414641
bool isInDeviceContext = Fortran::lower::isCudaDeviceContext(builder);
46424642

4643-
bool isCUDATransfer = (Fortran::evaluate::HasCUDADeviceAttrs(assign.lhs) ||
4644-
Fortran::evaluate::HasCUDADeviceAttrs(assign.rhs)) &&
4645-
!isInDeviceContext;
4643+
bool isCUDATransfer =
4644+
IsCUDADataTransfer(assign.lhs, assign.rhs) && !isInDeviceContext;
46464645
bool hasCUDAImplicitTransfer =
4646+
isCUDATransfer &&
46474647
Fortran::evaluate::HasCUDAImplicitTransfer(assign.rhs);
46484648
llvm::SmallVector<mlir::Value> implicitTemps;
46494649

flang/lib/Semantics/assignment.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,10 @@ void AssignmentContext::Analyze(const parser::AssignmentStmt &stmt) {
9898
if (!IsCUDADeviceContext(&progUnit) && deviceConstructDepth_ == 0) {
9999
if (Fortran::evaluate::HasCUDADeviceAttrs(lhs) &&
100100
Fortran::evaluate::HasCUDAImplicitTransfer(rhs)) {
101+
if (GetNbOfCUDAManagedOrUnifiedSymbols(lhs) == 1 &&
102+
GetNbOfCUDAManagedOrUnifiedSymbols(rhs) == 1 &&
103+
GetNbOfCUDADeviceSymbols(rhs) == 1)
104+
return; // This is a special case handled on the host.
101105
context_.Say(lhsLoc, "Unsupported CUDA data transfer"_err_en_US);
102106
}
103107
}

flang/test/Lower/CUDA/cuda-data-transfer.cuf

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,4 +393,13 @@ end subroutine
393393
! CHECK: %[[ALLOC_TMP:.*]] = fir.allocmem !fir.array<10xi32> {bindc_name = ".tmp", uniq_name = ""}
394394
! CHECK: %[[TMP:.*]]:2 = hlfir.declare %[[ALLOC_TMP]](%{{.*}}) {uniq_name = ".tmp"} : (!fir.heap<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<10xi32>>, !fir.heap<!fir.array<10xi32>>)
395395
! CHECK: cuf.data_transfer %[[ADEV_DECL]]#1 to %[[TMP]]#0 {transfer_kind = #cuf.cuda_transfer<device_host>} : !fir.ref<!fir.array<10xi32>>, !fir.heap<!fir.array<10xi32>>
396+
! CHECK: hlfir.assign
397+
398+
subroutine sub20()
399+
integer, managed :: a(10)
400+
a = a + 2 ! ok. No data transfer. Assignment on the host.
401+
end subroutine
402+
403+
! CHECK-LABEL: func.func @_QPsub20()
404+
! CHECK-NOT: cuf.data_transfer
396405
! CHECK: hlfir.assign

0 commit comments

Comments
 (0)