Skip to content

Commit 3433e41

Browse files
authored
[flang][cuda] Detect constant on the rhs of data transfer (#117806)
When the rhs expression contains constants and a device symbol, an implicit data transfer needs to be generated for the device symbol, and the computation with the constant is then done on the host.
1 parent 4d2bc0a commit 3433e41

File tree

4 files changed

+37
-2
lines changed

4 files changed

+37
-2
lines changed

flang/include/flang/Evaluate/tools.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1102,6 +1102,9 @@ extern template semantics::UnorderedSymbolSet CollectCudaSymbols(
11021102
// Predicate: does a variable contain a vector-valued subscript (not a triplet)?
11031103
bool HasVectorSubscript(const Expr<SomeType> &);
11041104

1105+
// Predicate: does an expression contain a constant?
1106+
bool HasConstant(const Expr<SomeType> &);
1107+
11051108
// Utilities for attaching the location of the declaration of a symbol
11061109
// of interest to a message. Handles the case of USE association gracefully.
11071110
parser::Message *AttachDeclaration(parser::Message &, const Symbol &);
@@ -1319,7 +1322,8 @@ inline bool HasCUDAImplicitTransfer(const Expr<SomeType> &expr) {
13191322
++hostSymbols;
13201323
}
13211324
}
1322-
return hostSymbols > 0 && deviceSymbols > 0;
1325+
bool hasConstant{HasConstant(expr)};
1326+
return (hasConstant || (hostSymbols > 0)) && deviceSymbols > 0;
13231327
}
13241328

13251329
} // namespace Fortran::evaluate

flang/lib/Evaluate/tools.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1051,6 +1051,23 @@ bool HasVectorSubscript(const Expr<SomeType> &expr) {
10511051
return HasVectorSubscriptHelper{}(expr);
10521052
}
10531053

1054+
// HasConstant()
1055+
struct HasConstantHelper : public AnyTraverse<HasConstantHelper, bool,
1056+
/*TraverseAssocEntityDetails=*/false> {
1057+
using Base = AnyTraverse<HasConstantHelper, bool, false>;
1058+
HasConstantHelper() : Base{*this} {}
1059+
using Base::operator();
1060+
template <typename T> bool operator()(const Constant<T> &) const {
1061+
return true;
1062+
}
1063+
// Only look for constants that are not inside a subscript.
1064+
bool operator()(const Subscript &) const { return false; }
1065+
};
1066+
1067+
bool HasConstant(const Expr<SomeType> &expr) {
1068+
return HasConstantHelper{}(expr);
1069+
}
1070+
10541071
parser::Message *AttachDeclaration(
10551072
parser::Message &message, const Symbol &symbol) {
10561073
const Symbol *unhosted{&symbol};

flang/lib/Lower/Bridge.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4416,6 +4416,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
44164416
bool hasCUDAImplicitTransfer =
44174417
Fortran::evaluate::HasCUDAImplicitTransfer(assign.rhs);
44184418
llvm::SmallVector<mlir::Value> implicitTemps;
4419+
44194420
if (hasCUDAImplicitTransfer && !isInDeviceContext)
44204421
implicitTemps = genCUDAImplicitDataTransfer(builder, loc, assign);
44214422

flang/test/Lower/CUDA/cuda-data-transfer.cuf

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ subroutine sub1()
3838
adev = 10
3939

4040
cdev = 0
41-
4241
end
4342

4443
! CHECK-LABEL: func.func @_QPsub1()
@@ -381,3 +380,17 @@ end subroutine
381380

382381
! CHECK-LABEL: func.func @_QPsub18
383382
! CHECK-NOT: cuf.data_transfer
383+
384+
subroutine sub19()
385+
integer, device :: adev(10)
386+
integer :: ahost(10)
387+
! Implicit data transfer of adev and then addition on the host
388+
ahost = adev + 2
389+
end subroutine
390+
391+
! CHECK-LABEL: func.func @_QPsub19()
392+
! CHECK: %[[ADEV_DECL:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}}) {data_attr = #cuf.cuda<device>, uniq_name = "_QFsub19Eadev"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>)
393+
! CHECK: %[[ALLOC_TMP:.*]] = fir.allocmem !fir.array<10xi32> {bindc_name = ".tmp", uniq_name = ""}
394+
! CHECK: %[[TMP:.*]]:2 = hlfir.declare %[[ALLOC_TMP]](%{{.*}}) {uniq_name = ".tmp"} : (!fir.heap<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<10xi32>>, !fir.heap<!fir.array<10xi32>>)
395+
! CHECK: cuf.data_transfer %[[ADEV_DECL]]#1 to %[[TMP]]#0 {transfer_kind = #cuf.cuda_transfer<device_host>} : !fir.ref<!fir.array<10xi32>>, !fir.heap<!fir.array<10xi32>>
396+
! CHECK: hlfir.assign

0 commit comments

Comments
 (0)