Skip to content

Commit e974ede

Browse files
committed
Passing descriptors by reference to CUDA runtime calls
1 parent 0227b73 commit e974ede

File tree

4 files changed

+23
-34
lines changed

4 files changed

+23
-34
lines changed

flang/include/flang/Runtime/CUDA/memory.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,19 +36,18 @@ void RTDECL(CUFDataTransferPtrPtr)(void *dst, void *src, std::size_t bytes,
3636
unsigned mode, const char *sourceFile = nullptr, int sourceLine = 0);
3737

3838
/// Data transfer from a pointer to a descriptor.
39-
void RTDECL(CUFDataTransferDescPtr)(const Descriptor &dst, void *src,
39+
void RTDECL(CUFDataTransferDescPtr)(Descriptor *dst, void *src,
4040
std::size_t bytes, unsigned mode, const char *sourceFile = nullptr,
4141
int sourceLine = 0);
4242

4343
/// Data transfer from a descriptor to a pointer.
44-
void RTDECL(CUFDataTransferPtrDesc)(void *dst, const Descriptor &src,
44+
void RTDECL(CUFDataTransferPtrDesc)(void *dst, Descriptor *src,
4545
std::size_t bytes, unsigned mode, const char *sourceFile = nullptr,
4646
int sourceLine = 0);
4747

4848
/// Data transfer from a descriptor to a descriptor.
49-
void RTDECL(CUFDataTransferDescDesc)(const Descriptor &dst,
50-
const Descriptor &src, unsigned mode, const char *sourceFile = nullptr,
51-
int sourceLine = 0);
49+
void RTDECL(CUFDataTransferDescDesc)(Descriptor *dst, Descriptor *src,
50+
unsigned mode, const char *sourceFile = nullptr, int sourceLine = 0);
5251

5352
} // extern "C"
5453
} // namespace Fortran::runtime::cuda

flang/lib/Optimizer/Transforms/CUFOpConversion.cpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -529,8 +529,8 @@ struct CUFDataTransferOpConversion
529529
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
530530
mlir::Value sourceLine =
531531
fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
532-
mlir::Value dst = builder.loadIfRef(loc, op.getDst());
533-
mlir::Value src = builder.loadIfRef(loc, op.getSrc());
532+
mlir::Value dst = op.getDst();
533+
mlir::Value src = op.getSrc();
534534
llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
535535
builder, loc, fTy, dst, src, modeValue, sourceFile, sourceLine)};
536536
builder.create<fir::CallOp>(loc, func, args);
@@ -603,11 +603,8 @@ struct CUFDataTransferOpConversion
603603
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
604604
mlir::Value sourceLine =
605605
fir::factory::locationToLineNo(builder, loc, fTy.getInput(5));
606-
mlir::Value dst =
607-
dstIsDesc ? builder.loadIfRef(loc, op.getDst()) : op.getDst();
608-
mlir::Value src = mlir::isa<fir::BaseBoxType>(srcTy)
609-
? builder.loadIfRef(loc, op.getSrc())
610-
: op.getSrc();
606+
mlir::Value dst = op.getDst();
607+
mlir::Value src = op.getSrc();
611608
llvm::SmallVector<mlir::Value> args{
612609
fir::runtime::createArguments(builder, loc, fTy, dst, src, bytes,
613610
modeValue, sourceFile, sourceLine)};

flang/runtime/CUDA/memory.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,23 +73,22 @@ void RTDEF(CUFDataTransferPtrPtr)(void *dst, void *src, std::size_t bytes,
7373
CUDA_REPORT_IF_ERROR(cudaMemcpy(dst, src, bytes, kind));
7474
}
7575

76-
void RTDEF(CUFDataTransferDescPtr)(const Descriptor &desc, void *addr,
76+
void RTDEF(CUFDataTransferDescPtr)(Descriptor *desc, void *addr,
7777
std::size_t bytes, unsigned mode, const char *sourceFile, int sourceLine) {
7878
Terminator terminator{sourceFile, sourceLine};
7979
terminator.Crash(
8080
"not yet implemented: CUDA data transfer from a pointer to a descriptor");
8181
}
8282

83-
void RTDEF(CUFDataTransferPtrDesc)(void *addr, const Descriptor &desc,
83+
void RTDEF(CUFDataTransferPtrDesc)(void *addr, Descriptor *desc,
8484
std::size_t bytes, unsigned mode, const char *sourceFile, int sourceLine) {
8585
Terminator terminator{sourceFile, sourceLine};
8686
terminator.Crash(
8787
"not yet implemented: CUDA data transfer from a descriptor to a pointer");
8888
}
8989

90-
void RTDECL(CUFDataTransferDescDesc)(const Descriptor &dstDesc,
91-
const Descriptor &srcDesc, unsigned mode, const char *sourceFile,
92-
int sourceLine) {
90+
void RTDECL(CUFDataTransferDescDesc)(Descriptor *dstDesc, Descriptor *srcDesc,
91+
unsigned mode, const char *sourceFile, int sourceLine) {
9392
Terminator terminator{sourceFile, sourceLine};
9493
terminator.Crash(
9594
"not yet implemented: CUDA data transfer between two descriptors");

flang/test/Fir/CUDA/cuda-data-transfer.fir

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,9 @@ func.func @_QPsub1() {
1515
// CHECK-LABEL: func.func @_QPsub1()
1616
// CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
1717
// CHECK: %[[AHOST:.*]]:2 = hlfir.declare %{{.*}} {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Eahost"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
18-
// CHECK: %[[AHOST_LOAD:.*]] = fir.load %[[AHOST]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
19-
// CHECK: %[[ADEV_LOAD:.*]] = fir.load %[[ADEV]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
20-
// CHECK: %[[AHOST_BOX:.*]] = fir.convert %[[AHOST_LOAD]] : (!fir.box<!fir.heap<!fir.array<?xi32>>>) -> !fir.box<none>
21-
// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV_LOAD]] : (!fir.box<!fir.heap<!fir.array<?xi32>>>) -> !fir.box<none>
22-
// CHECK: fir.call @_FortranACUFDataTransferDescDesc(%[[AHOST_BOX]], %[[ADEV_BOX]], %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.box<none>, !fir.box<none>, i32, !fir.ref<i8>, i32) -> none
18+
// CHECK: %[[AHOST_BOX:.*]] = fir.convert %[[AHOST]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
19+
// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
20+
// CHECK: fir.call @_FortranACUFDataTransferDescDesc(%[[AHOST_BOX]], %[[ADEV_BOX]], %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
2321

2422
func.func @_QPsub2() {
2523
%0 = cuf.alloc !fir.box<!fir.heap<!fir.array<?xi32>>> {bindc_name = "adev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub2Eadev"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
@@ -76,19 +74,17 @@ func.func @_QPsub4() {
7674
// CHECK: %[[NBELEM:.*]] = arith.constant 10 : index
7775
// CHECK: %[[WIDTH:.*]] = arith.constant 4 : index
7876
// CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %[[WIDTH]] : index
79-
// CHECK: %[[ADEV_LOAD:.*]] = fir.load %[[ADEV]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
80-
// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV_LOAD]] : (!fir.box<!fir.heap<!fir.array<?xi32>>>) -> !fir.box<none>
77+
// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
8178
// CHECK: %[[AHOST_PTR:.*]] = fir.convert %[[AHOST]]#0 : (!fir.ref<!fir.array<10xi32>>) -> !fir.llvm_ptr<i8>
8279
// CHECK: %[[BYTES_CONV:.*]] = fir.convert %[[BYTES]] : (index) -> i64
83-
// CHECK: fir.call @_FortranACUFDataTransferDescPtr(%[[ADEV_BOX]], %[[AHOST_PTR]], %[[BYTES_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.box<none>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
80+
// CHECK: fir.call @_FortranACUFDataTransferDescPtr(%[[ADEV_BOX]], %[[AHOST_PTR]], %[[BYTES_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
8481
// CHECK: %[[NBELEM:.*]] = arith.constant 10 : index
8582
// CHECK: %[[WIDTH:.*]] = arith.constant 4 : index
8683
// CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %[[WIDTH]] : index
87-
// CHECK: %[[ADEV_LOAD:.*]] = fir.load %[[ADEV]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
8884
// CHECK: %[[AHOST_PTR:.*]] = fir.convert %[[AHOST]]#0 : (!fir.ref<!fir.array<10xi32>>) -> !fir.llvm_ptr<i8>
89-
// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV_LOAD]] : (!fir.box<!fir.heap<!fir.array<?xi32>>>) -> !fir.box<none>
85+
// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
9086
// CHECK: %[[BYTES_CONV:.*]] = fir.convert %[[BYTES]] : (index) -> i64
91-
// CHECK: fir.call @_FortranACUFDataTransferPtrDesc(%[[AHOST_PTR]], %[[ADEV_BOX]], %[[BYTES_CONV]], %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.box<none>, i64, i32, !fir.ref<i8>, i32) -> none
87+
// CHECK: fir.call @_FortranACUFDataTransferPtrDesc(%[[AHOST_PTR]], %[[ADEV_BOX]], %[[BYTES_CONV]], %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<!fir.box<none>>, i64, i32, !fir.ref<i8>, i32) -> none
9288

9389
func.func @_QPsub5(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}) {
9490
%0 = fir.dummy_scope : !fir.dscope
@@ -122,19 +118,17 @@ func.func @_QPsub5(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}) {
122118
// CHECK: %[[NBELEM:.*]] = arith.muli %[[I1]], %[[I2]] : index
123119
// CHECK: %[[WIDTH:.*]] = arith.constant 4 : index
124120
// CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %[[WIDTH]] : index
125-
// CHECK: %[[ADEV_LOAD:.*]] = fir.load %[[ADEV]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>
126-
// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV_LOAD]] : (!fir.box<!fir.heap<!fir.array<?x?xi32>>>) -> !fir.box<none>
121+
// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>) -> !fir.ref<!fir.box<none>>
127122
// CHECK: %[[AHOST_PTR:.*]] = fir.convert %[[AHOST]]#1 : (!fir.ref<!fir.array<?x?xi32>>) -> !fir.llvm_ptr<i8>
128123
// CHECK: %[[BYTES_CONV:.*]] = fir.convert %[[BYTES]] : (index) -> i64
129-
// CHECK: fir.call @_FortranACUFDataTransferDescPtr(%[[ADEV_BOX]], %[[AHOST_PTR]], %[[BYTES_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.box<none>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
124+
// CHECK: fir.call @_FortranACUFDataTransferDescPtr(%[[ADEV_BOX]], %[[AHOST_PTR]], %[[BYTES_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
130125
// CHECK: %[[NBELEM:.*]] = arith.muli %[[I1]], %[[I2]] : index
131126
// CHECK: %[[WIDTH:.*]] = arith.constant 4 : index
132127
// CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %[[WIDTH]] : index
133-
// CHECK: %[[ADEV_LOAD:.*]] = fir.load %[[ADEV]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>
134128
// CHECK: %[[AHOST_PTR:.*]] = fir.convert %[[AHOST]]#1 : (!fir.ref<!fir.array<?x?xi32>>) -> !fir.llvm_ptr<i8>
135-
// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV_LOAD]] : (!fir.box<!fir.heap<!fir.array<?x?xi32>>>) -> !fir.box<none>
129+
// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>) -> !fir.ref<!fir.box<none>>
136130
// CHECK: %[[BYTES_CONV:.*]] = fir.convert %[[BYTES]] : (index) -> i64
137-
// CHECK: fir.call @_FortranACUFDataTransferPtrDesc(%[[AHOST_PTR]], %[[ADEV_BOX]], %[[BYTES_CONV]], %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.box<none>, i64, i32, !fir.ref<i8>, i32) -> none
131+
// CHECK: fir.call @_FortranACUFDataTransferPtrDesc(%[[AHOST_PTR]], %[[ADEV_BOX]], %[[BYTES_CONV]], %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<!fir.box<none>>, i64, i32, !fir.ref<i8>, i32) -> none
138132

139133
func.func @_QPsub6() {
140134
%0 = cuf.alloc i32 {bindc_name = "idev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub6Eidev"} -> !fir.ref<i32>

0 commit comments

Comments
 (0)