Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions flang/include/flang/Runtime/CUDA/memory.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ void RTDECL(CUFDataTransferPtrDesc)(void *dst, Descriptor *src,
void RTDECL(CUFDataTransferDescDesc)(Descriptor *dst, Descriptor *src,
unsigned mode, const char *sourceFile = nullptr, int sourceLine = 0);

/// Data transfer from a descriptor to a global descriptor.
void RTDECL(CUFDataTransferGlobalDescDesc)(Descriptor *dst, Descriptor *src,
unsigned mode, const char *sourceFile = nullptr, int sourceLine = 0);

} // extern "C"
} // namespace Fortran::runtime::cuda
#endif // FORTRAN_RUNTIME_CUDA_MEMORY_H_
17 changes: 15 additions & 2 deletions flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,16 @@ struct CUFFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> {
}
};

static bool isDstGlobal(cuf::DataTransferOp op) {
if (auto declareOp = op.getDst().getDefiningOp<fir::DeclareOp>())
if (declareOp.getMemref().getDefiningOp<fir::AddrOfOp>())
return true;
if (auto declareOp = op.getDst().getDefiningOp<hlfir::DeclareOp>())
if (declareOp.getMemref().getDefiningOp<fir::AddrOfOp>())
return true;
return false;
}

struct CUFDataTransferOpConversion
: public mlir::OpRewritePattern<cuf::DataTransferOp> {
using OpRewritePattern::OpRewritePattern;
Expand Down Expand Up @@ -522,8 +532,11 @@ struct CUFDataTransferOpConversion
mlir::isa<fir::BaseBoxType>(dstTy)) {
// Transfer between two descriptor.
mlir::func::FuncOp func =
fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferDescDesc)>(
loc, builder);
isDstGlobal(op)
? fir::runtime::getRuntimeFunc<mkRTKey(
CUFDataTransferGlobalDescDesc)>(loc, builder)
: fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferDescDesc)>(
loc, builder);

auto fTy = func.getFunctionType();
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
Expand Down
14 changes: 14 additions & 0 deletions flang/runtime/CUDA/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "flang/Runtime/CUDA/memory.h"
#include "../terminator.h"
#include "flang/Runtime/CUDA/common.h"
#include "flang/Runtime/CUDA/descriptor.h"
#include "flang/Runtime/assign.h"

#include "cuda_runtime.h"
Expand Down Expand Up @@ -123,5 +124,18 @@ void RTDECL(CUFDataTransferDescDesc)(Descriptor *dstDesc, Descriptor *srcDesc,
Fortran::runtime::Assign(
*dstDesc, *srcDesc, terminator, MaybeReallocate, memmoveFct);
}

void RTDECL(CUFDataTransferGlobalDescDesc)(Descriptor *dstDesc,
Descriptor *srcDesc, unsigned mode, const char *sourceFile,
int sourceLine) {
RTNAME(CUFDataTransferDescDesc)
(dstDesc, srcDesc, mode, sourceFile, sourceLine);
if ((mode == kHostToDevice) || (mode == kDeviceToDevice)) {
void *deviceAddr{
RTNAME(CUFGetDeviceAddress)((void *)dstDesc, sourceFile, sourceLine)};
RTNAME(CUFDescriptorSync)
((Descriptor *)deviceAddr, srcDesc, sourceFile, sourceLine);
}
}
}
} // namespace Fortran::runtime::cuda
25 changes: 25 additions & 0 deletions flang/test/Fir/CUDA/cuda-data-transfer.fir
Original file line number Diff line number Diff line change
Expand Up @@ -224,4 +224,29 @@ func.func @_QPsub9() {
// CHECK: %[[DST:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[HOST]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
// CHECK: %[[SRC:.*]] = fir.convert %[[LOCAL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
// CHECK: fir.call @_FortranACUFDataTransferPtrPtr(%[[DST]], %[[SRC]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none

fir.global @_QMmod1Ea {data_attr = #cuf.cuda<device>} : !fir.box<!fir.heap<!fir.array<?xi32>>> {
%c0 = arith.constant 0 : index
%0 = fir.zero_bits !fir.heap<!fir.array<?xi32>>
%1 = fir.shape %c0 : (index) -> !fir.shape<1>
%2 = fir.embox %0(%1) {allocator_idx = 2 : i32} : (!fir.heap<!fir.array<?xi32>>, !fir.shape<1>) -> !fir.box<!fir.heap<!fir.array<?xi32>>>
fir.has_value %2 : !fir.box<!fir.heap<!fir.array<?xi32>>>
}

func.func @_QQdesc_global() attributes {fir.bindc_name = "host_sub"} {
%0 = fir.address_of(@_QMmod1Ea) : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
%1:2 = hlfir.declare %0 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMmod1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
%2 = fir.address_of(@_QFEahost) : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
%3:2 = hlfir.declare %2 {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFEahost"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
cuf.data_transfer %3#0 to %1#0 {transfer_kind = #cuf.cuda_transfer<host_device>} : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
return
}

// CHECK-LABEL: func.func @_QQdesc_global() attributes {fir.bindc_name = "host_sub"}
// CHECK: %[[GLOBAL_ADDRESS:.*]] = fir.address_of(@_QMmod1Ea) : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
// CHECK: %[[GLOBAL_DECL:.*]]:2 = hlfir.declare %[[GLOBAL_ADDRESS]] {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMmod1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
// CHECK: %[[BOX_NONE:.*]] = fir.convert %[[GLOBAL_DECL:.*]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
// CHECK: fir.call @_FortranACUFDataTransferGlobalDescDesc(%[[BOX_NONE]],{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none


} // end of module
Loading