@@ -1303,6 +1303,18 @@ inline bool IsCUDADeviceSymbol(const Symbol &sym) {
13031303 return false ;
13041304}
13051305
1306+ inline bool IsCUDAManagedOrUnifiedSymbol (const Symbol &sym) {
1307+ if (const auto *details =
1308+ sym.GetUltimate ().detailsIf <semantics::ObjectEntityDetails>()) {
1309+ if (details->cudaDataAttr () &&
1310+ (*details->cudaDataAttr () == common::CUDADataAttr::Managed ||
1311+ *details->cudaDataAttr () == common::CUDADataAttr::Unified)) {
1312+ return true ;
1313+ }
1314+ }
1315+ return false ;
1316+ }
1317+
13061318// Get the number of distinct symbols with CUDA device
13071319// attribute in the expression.
13081320template <typename A> inline int GetNbOfCUDADeviceSymbols (const A &expr) {
@@ -1315,12 +1327,42 @@ template <typename A> inline int GetNbOfCUDADeviceSymbols(const A &expr) {
13151327 return symbols.size ();
13161328}
13171329
1330+ // Get the number of distinct symbols with CUDA managed or unified
1331+ // attribute in the expression.
1332+ template <typename A>
1333+ inline int GetNbOfCUDAManagedOrUnifiedSymbols (const A &expr) {
1334+ semantics::UnorderedSymbolSet symbols;
1335+ for (const Symbol &sym : CollectCudaSymbols (expr)) {
1336+ if (IsCUDAManagedOrUnifiedSymbol (sym)) {
1337+ symbols.insert (sym);
1338+ }
1339+ }
1340+ return symbols.size ();
1341+ }
1342+
// Tell whether the expression references at least one symbol that has a
// CUDA device attribute.
template <typename A> inline bool HasCUDADeviceAttrs(const A &expr) {
  const int deviceSymbolCount{GetNbOfCUDADeviceSymbols(expr)};
  return deviceSymbolCount > 0;
}
13231348
// Check whether the lhs/rhs pair involves a CUDA data transfer: true when
// the lhs or the rhs references at least one symbol with a CUDA device
// attribute. One combination is excluded and yields false: exactly one
// managed or unified symbol on the lhs, exactly one managed or unified
// symbol on the rhs, and exactly one device symbol on the rhs.
template <typename A, typename B>
inline bool IsCUDADataTransfer(const A &lhs, const B &rhs) {
  const int managedOnLhs{GetNbOfCUDAManagedOrUnifiedSymbols(lhs)};
  const int managedOnRhs{GetNbOfCUDAManagedOrUnifiedSymbols(rhs)};
  const int deviceOnRhs{GetNbOfCUDADeviceSymbols(rhs)};

  // Special case where only managed or unified symbols are involved. This is
  // performed on the host.
  const bool managedOnly{
      managedOnLhs == 1 && managedOnRhs == 1 && deviceOnRhs == 1};

  // The lhs is only inspected for device attributes outside the special case.
  return !managedOnly && (HasCUDADeviceAttrs(lhs) || deviceOnRhs > 0);
}
1365+
13241366// / Check if the expression is a mix of host and device variables that require
13251367// / implicit data transfer.
13261368inline bool HasCUDAImplicitTransfer (const Expr<SomeType> &expr) {
0 commit comments