diff --git a/flang/include/flang/Runtime/CUDA/pointer.h b/flang/include/flang/Runtime/CUDA/pointer.h index 2197d85f4b93e..78c7a1a92b7ea 100644 --- a/flang/include/flang/Runtime/CUDA/pointer.h +++ b/flang/include/flang/Runtime/CUDA/pointer.h @@ -21,6 +21,12 @@ int RTDECL(CUFPointerAllocate)(Descriptor &, int64_t stream = -1, bool hasStat = false, const Descriptor *errMsg = nullptr, const char *sourceFile = nullptr, int sourceLine = 0); +/// Perform allocation of the descriptor with synchronization of it when +/// necessary. +int RTDECL(CUFPointerAllocateSync)(Descriptor &, int64_t stream = -1, + bool hasStat = false, const Descriptor *errMsg = nullptr, + const char *sourceFile = nullptr, int sourceLine = 0); + /// Perform allocation of the descriptor without synchronization. Assign data /// from source. int RTDEF(CUFPointerAllocateSource)(Descriptor &pointer, @@ -28,6 +34,13 @@ int RTDEF(CUFPointerAllocateSource)(Descriptor &pointer, const Descriptor *errMsg = nullptr, const char *sourceFile = nullptr, int sourceLine = 0); +/// Perform allocation of the descriptor with synchronization of it when +/// necessary. Assign data from source. +int RTDEF(CUFPointerAllocateSourceSync)(Descriptor &pointer, + const Descriptor &source, int64_t stream = -1, bool hasStat = false, + const Descriptor *errMsg = nullptr, const char *sourceFile = nullptr, + int sourceLine = 0); + } // extern "C" } // namespace Fortran::runtime::cuda diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp index b0d6b0f0993a6..7292ce741b85b 100644 --- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp +++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp @@ -172,18 +172,22 @@ struct CUFAllocateOpConversion isPointer = true; if (hasDoubleDescriptors(op)) { - if (isPointer) - TODO(loc, "pointer allocation with double descriptors"); // Allocation for module variable are done with custom runtime entry point // so the descriptors can be synchronized. mlir::func::FuncOp func; - if (op.getSource()) - func = fir::runtime::getRuntimeFunc(loc, builder); - else + if (op.getSource()) { + func = isPointer ? fir::runtime::getRuntimeFunc(loc, builder) + : fir::runtime::getRuntimeFunc(loc, builder); + } else { func = - fir::runtime::getRuntimeFunc( - loc, builder); + isPointer + ? fir::runtime::getRuntimeFunc( + loc, builder) + : fir::runtime::getRuntimeFunc(loc, builder); + } return convertOpToCall(op, rewriter, func); } diff --git a/flang/runtime/CUDA/pointer.cpp b/flang/runtime/CUDA/pointer.cpp index 35f373b0a56c3..3252410bd8d2c 100644 --- a/flang/runtime/CUDA/pointer.cpp +++ b/flang/runtime/CUDA/pointer.cpp @@ -10,6 +10,7 @@ #include "../assign-impl.h" #include "../stat.h" #include "../terminator.h" +#include "flang/Runtime/CUDA/descriptor.h" #include "flang/Runtime/CUDA/memmove-function.h" #include "flang/Runtime/pointer.h" @@ -35,6 +36,24 @@ int RTDEF(CUFPointerAllocate)(Descriptor &desc, int64_t stream, bool hasStat, return stat; } +int RTDEF(CUFPointerAllocateSync)(Descriptor &desc, int64_t stream, + bool hasStat, const Descriptor *errMsg, const char *sourceFile, + int sourceLine) { + int stat{RTNAME(CUFPointerAllocate)( + desc, stream, hasStat, errMsg, sourceFile, sourceLine)}; +#ifndef RT_DEVICE_COMPILATION + // Descriptor synchronization is only done when the allocation is done + // from the host. + if (stat == StatOk) { + void *deviceAddr{ + RTNAME(CUFGetDeviceAddress)((void *)&desc, sourceFile, sourceLine)}; + RTNAME(CUFDescriptorSync) + ((Descriptor *)deviceAddr, &desc, sourceFile, sourceLine); + } +#endif + return stat; +} + int RTDEF(CUFPointerAllocateSource)(Descriptor &pointer, const Descriptor &source, int64_t stream, bool hasStat, const Descriptor *errMsg, const char *sourceFile, int sourceLine) { @@ -48,6 +67,19 @@ int RTDEF(CUFPointerAllocateSource)(Descriptor &pointer, return stat; } +int RTDEF(CUFPointerAllocateSourceSync)(Descriptor &pointer, + const Descriptor &source, int64_t stream, bool hasStat, + const Descriptor *errMsg, const char *sourceFile, int sourceLine) { + int stat{RTNAME(CUFPointerAllocateSync)( + pointer, stream, hasStat, errMsg, sourceFile, sourceLine)}; + if (stat == StatOk) { + Terminator terminator{sourceFile, sourceLine}; + Fortran::runtime::DoFromSourceAssign( + pointer, source, terminator, &MemmoveHostToDevice); + } + return stat; +} + RT_EXT_API_GROUP_END } // extern "C" diff --git a/flang/test/Fir/CUDA/cuda-allocate.fir b/flang/test/Fir/CUDA/cuda-allocate.fir index 804bb8636685d..b8457b846716e 100644 --- a/flang/test/Fir/CUDA/cuda-allocate.fir +++ b/flang/test/Fir/CUDA/cuda-allocate.fir @@ -198,16 +198,61 @@ func.func @_QPpointer_source() { %c0_i32 = arith.constant 0 : i32 %c1 = arith.constant 1 : index %c0 = arith.constant 0 : index - %0 = fir.alloca !fir.box>> {bindc_name = "a", uniq_name = "_QFpointer_sourceEa"} - %4 = fir.declare %0 {fortran_attrs = #fir.var_attrs, uniq_name = "_QFpointer_sourceEa"} : (!fir.ref>>>) -> !fir.ref>>> - %5 = cuf.alloc !fir.box>> {bindc_name = "a_d", data_attr = #cuf.cuda, uniq_name = "_QFpointer_sourceEa_d"} -> !fir.ref>>> - %7 = fir.declare %5 {data_attr = #cuf.cuda, fortran_attrs = #fir.var_attrs, uniq_name = "_QFpointer_sourceEa_d"} : (!fir.ref>>>) -> !fir.ref>>> - %8 = fir.load %4 : !fir.ref>>> - %22 = cuf.allocate %7 : !fir.ref>>> source(%8 : !fir.box>>) {data_attr = #cuf.cuda} -> i32 + %0 = fir.alloca !fir.box>> {bindc_name = "a", uniq_name = "_QFpointer_sourceEa"} + %4 = fir.declare %0 {fortran_attrs = #fir.var_attrs, uniq_name = "_QFpointer_sourceEa"} : (!fir.ref>>>) -> !fir.ref>>> + %5 = cuf.alloc !fir.box>> {bindc_name = "a_d", data_attr = #cuf.cuda, uniq_name = "_QFpointer_sourceEa_d"} -> !fir.ref>>> + %7 = fir.declare %5 {data_attr = #cuf.cuda, fortran_attrs = #fir.var_attrs, uniq_name = "_QFpointer_sourceEa_d"} : (!fir.ref>>>) -> !fir.ref>>> + %8 = fir.load %4 : !fir.ref>>> + %22 = cuf.allocate %7 : !fir.ref>>> source(%8 : !fir.box>>) {data_attr = #cuf.cuda} -> i32 return } // CHECK-LABEL: func.func @_QPpointer_source() // CHECK: _FortranACUFPointerAllocateSource +fir.global @_QMdataEb2 {data_attr = #cuf.cuda} : !fir.box>> { + %c0 = arith.constant 0 : index + %0 = fir.zero_bits !fir.ptr> + %1 = fir.shape %c0 : (index) -> !fir.shape<1> + %2 = fir.embox %0(%1) {allocator_idx = 2 : i32} : (!fir.ptr>, !fir.shape<1>) -> !fir.box>> + fir.has_value %2 : !fir.box>> +} + +func.func @_QQpointer_sync() attributes {fir.bindc_name = "test"} { + %c0_i32 = arith.constant 0 : i32 + %c10_i32 = arith.constant 10 : i32 + %c1 = arith.constant 1 : index + %0 = fir.address_of(@_QMdataEb2) : !fir.ref>>> + %1 = fir.declare %0 {data_attr = #cuf.cuda, fortran_attrs = #fir.var_attrs, uniq_name = "_QMdataEb"} : (!fir.ref>>>) -> (!fir.ref>>>) + %2 = fir.convert %1 : (!fir.ref>>>) -> !fir.ref> + %3 = fir.convert %c1 : (index) -> i64 + %4 = fir.convert %c10_i32 : (i32) -> i64 + fir.call @_FortranAAllocatableSetBounds(%2, %c0_i32, %3, %4) fastmath : (!fir.ref>, i32, i64, i64) -> () + %6 = cuf.allocate %1 : !fir.ref>>> {data_attr = #cuf.cuda} -> i32 + return +} + +// CHECK-LABEL: func.func @_QQpointer_sync() +// CHECK: _FortranACUFPointerAllocateSync + +fir.global @_QMmod1Ea_d2 {data_attr = #cuf.cuda} : !fir.box>> { + %c0 = arith.constant 0 : index + %0 = fir.zero_bits !fir.ptr> + %1 = fir.shape %c0, %c0 : (index, index) -> !fir.shape<2> + %2 = fir.embox %0(%1) {allocator_idx = 2 : i32} : (!fir.ptr>, !fir.shape<2>) -> !fir.box>> + fir.has_value %2 : !fir.box>> +} +func.func @_QMmod1Ppointer_source_global() { + %0 = fir.address_of(@_QMmod1Ea_d2) : !fir.ref>>> + %1 = fir.declare %0 {data_attr = #cuf.cuda, fortran_attrs = #fir.var_attrs, uniq_name = "_QMmod1Ea_d"} : (!fir.ref>>>) -> !fir.ref>>> + %2 = fir.alloca !fir.box>> {bindc_name = "a", uniq_name = "_QMmod1Fallocate_source_globalEa"} + %6 = fir.declare %2 {fortran_attrs = #fir.var_attrs, uniq_name = "_QMmod1Fallocate_source_globalEa"} : (!fir.ref>>>) -> !fir.ref>>> + %7 = fir.load %6 : !fir.ref>>> + %21 = cuf.allocate %1 : !fir.ref>>> source(%7 : !fir.box>>) {data_attr = #cuf.cuda} -> i32 + return +} + +// CHECK-LABEL: func.func @_QMmod1Ppointer_source_global() +// CHECK: fir.call @_FortranACUFPointerAllocateSourceSync + } // end of module