@@ -1157,23 +1157,28 @@ template semantics::UnorderedSymbolSet CollectCudaSymbols(
11571157bool HasCUDAImplicitTransfer (const Expr<SomeType> &expr) {
11581158 semantics::UnorderedSymbolSet hostSymbols;
11591159 semantics::UnorderedSymbolSet deviceSymbols;
1160+ semantics::UnorderedSymbolSet cudaSymbols{CollectCudaSymbols (expr)};
11601161
11611162 SymbolVector symbols{GetSymbolVector (expr)};
11621163 std::reverse (symbols.begin (), symbols.end ());
11631164 bool skipNext{false };
11641165 for (const Symbol &sym : symbols) {
1165- bool isComponent{sym.owner ().IsDerivedType ()};
1166- bool skipComponent{false };
1167- if (!skipNext) {
1168- if (IsCUDADeviceSymbol (sym)) {
1169- deviceSymbols.insert (sym);
1170- } else if (isComponent) {
1171- skipComponent = true ; // Component is not device. Look on the base.
1172- } else {
1173- hostSymbols.insert (sym);
1166+ if (cudaSymbols.find (sym) != cudaSymbols.end ()) {
1167+ bool isComponent{sym.owner ().IsDerivedType ()};
1168+ bool skipComponent{false };
1169+ if (!skipNext) {
1170+ if (IsCUDADeviceSymbol (sym)) {
1171+ deviceSymbols.insert (sym);
1172+ } else if (isComponent) {
1173+ skipComponent = true ; // Component is not device. Look on the base.
1174+ } else {
1175+ hostSymbols.insert (sym);
1176+ }
11741177 }
1178+ skipNext = isComponent && !skipComponent;
1179+ } else {
1180+ skipNext = false ;
11751181 }
1176- skipNext = isComponent && !skipComponent;
11771182 }
11781183 bool hasConstant{HasConstant (expr)};
11791184 return (hasConstant || (hostSymbols.size () > 0 )) && deviceSymbols.size () > 0 ;
0 commit comments