Skip to content

Commit db69d69

Browse files
authored
[flang][cuda] Support data transfer from descriptor to a pointer (#115023)
Data transfer from a variable with a descriptor to a pointer. We create a descriptor for the pointer so we can use the flang runtime to perform the transfer. The Assign function handles all corner cases. We add a new entry points `CUFDataTransferDescDescNoRealloc` to avoid reallocation since the variable on the LHS is not an allocatable.
1 parent 17d9565 commit db69d69

File tree

4 files changed

+46
-49
lines changed

4 files changed

+46
-49
lines changed

flang/include/flang/Runtime/CUDA/memory.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@ void RTDECL(CUFDataTransferPtrDesc)(void *dst, Descriptor *src,
4444
void RTDECL(CUFDataTransferDescDesc)(Descriptor *dst, Descriptor *src,
4545
unsigned mode, const char *sourceFile = nullptr, int sourceLine = 0);
4646

47+
/// Data transfer from a descriptor to a descriptor.
48+
void RTDECL(CUFDataTransferDescDescNoRealloc)(Descriptor *dst, Descriptor *src,
49+
unsigned mode, const char *sourceFile = nullptr, int sourceLine = 0);
50+
4751
/// Data transfer from a descriptor to a global descriptor.
4852
void RTDECL(CUFDataTransferGlobalDescDesc)(Descriptor *dst, Descriptor *src,
4953
unsigned mode, const char *sourceFile = nullptr, int sourceLine = 0);

flang/lib/Optimizer/Transforms/CUFOpConversion.cpp

Lines changed: 14 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -581,50 +581,27 @@ struct CUFDataTransferOpConversion
581581
builder.create<fir::CallOp>(loc, func, args);
582582
rewriter.eraseOp(op);
583583
} else {
584-
// Type used to compute the width.
585-
mlir::Type computeType = dstTy;
586-
auto seqTy = mlir::dyn_cast<fir::SequenceType>(dstTy);
587-
if (mlir::isa<fir::BaseBoxType>(dstTy)) {
588-
computeType = srcTy;
589-
seqTy = mlir::dyn_cast<fir::SequenceType>(srcTy);
590-
}
591-
int width = computeWidth(loc, computeType, kindMap);
584+
// Transfer from a descriptor.
592585

593-
mlir::Value nbElement;
594-
mlir::Type idxTy = rewriter.getIndexType();
595-
if (!op.getShape()) {
596-
nbElement = rewriter.create<mlir::arith::ConstantOp>(
597-
loc, idxTy,
598-
rewriter.getIntegerAttr(idxTy, seqTy.getConstantArraySize()));
599-
} else {
600-
auto shapeOp =
601-
mlir::dyn_cast<fir::ShapeOp>(op.getShape().getDefiningOp());
602-
nbElement =
603-
createConvertOp(rewriter, loc, idxTy, shapeOp.getExtents()[0]);
604-
for (unsigned i = 1; i < shapeOp.getExtents().size(); ++i) {
605-
auto operand =
606-
createConvertOp(rewriter, loc, idxTy, shapeOp.getExtents()[i]);
607-
nbElement =
608-
rewriter.create<mlir::arith::MulIOp>(loc, nbElement, operand);
609-
}
610-
}
586+
mlir::Value addr = getDeviceAddress(rewriter, op.getDstMutable(), symtab);
587+
mlir::Type boxTy = fir::BoxType::get(dstTy);
588+
llvm::SmallVector<mlir::Value> lenParams;
589+
mlir::Value box =
590+
builder.createBox(loc, boxTy, addr, getShapeFromDecl(op.getDst()),
591+
/*slice=*/nullptr, lenParams,
592+
/*tdesc=*/nullptr);
593+
mlir::Value memBox = builder.createTemporary(loc, box.getType());
594+
builder.create<fir::StoreOp>(loc, box, memBox);
611595

612-
mlir::Value widthValue = rewriter.create<mlir::arith::ConstantOp>(
613-
loc, idxTy, rewriter.getIntegerAttr(idxTy, width));
614-
mlir::Value bytes =
615-
rewriter.create<mlir::arith::MulIOp>(loc, nbElement, widthValue);
596+
mlir::func::FuncOp func = fir::runtime::getRuntimeFunc<mkRTKey(
597+
CUFDataTransferDescDescNoRealloc)>(loc, builder);
616598

617-
mlir::func::FuncOp func =
618-
fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferPtrDesc)>(
619-
loc, builder);
620599
auto fTy = func.getFunctionType();
621600
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
622601
mlir::Value sourceLine =
623-
fir::factory::locationToLineNo(builder, loc, fTy.getInput(5));
624-
mlir::Value dst = op.getDst();
625-
mlir::Value src = op.getSrc();
602+
fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
626603
llvm::SmallVector<mlir::Value> args{
627-
fir::runtime::createArguments(builder, loc, fTy, dst, src, bytes,
604+
fir::runtime::createArguments(builder, loc, fTy, memBox, op.getSrc(),
628605
modeValue, sourceFile, sourceLine)};
629606
builder.create<fir::CallOp>(loc, func, args);
630607
rewriter.eraseOp(op);

flang/runtime/CUDA/memory.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,24 @@ void RTDECL(CUFDataTransferDescDesc)(Descriptor *dstDesc, Descriptor *srcDesc,
120120
*dstDesc, *srcDesc, terminator, MaybeReallocate, memmoveFct);
121121
}
122122

123+
void RTDECL(CUFDataTransferDescDescNoRealloc)(Descriptor *dstDesc,
124+
Descriptor *srcDesc, unsigned mode, const char *sourceFile,
125+
int sourceLine) {
126+
MemmoveFct memmoveFct;
127+
Terminator terminator{sourceFile, sourceLine};
128+
if (mode == kHostToDevice) {
129+
memmoveFct = &MemmoveHostToDevice;
130+
} else if (mode == kDeviceToHost) {
131+
memmoveFct = &MemmoveDeviceToHost;
132+
} else if (mode == kDeviceToDevice) {
133+
memmoveFct = &MemmoveDeviceToDevice;
134+
} else {
135+
terminator.Crash("host to host copy not supported");
136+
}
137+
Fortran::runtime::Assign(
138+
*dstDesc, *srcDesc, terminator, NoAssignFlags, memmoveFct);
139+
}
140+
123141
void RTDECL(CUFDataTransferGlobalDescDesc)(Descriptor *dstDesc,
124142
Descriptor *srcDesc, unsigned mode, const char *sourceFile,
125143
int sourceLine) {

flang/test/Fir/CUDA/cuda-data-transfer.fir

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ func.func @_QPsub4() {
7373
return
7474
}
7575
// CHECK-LABEL: func.func @_QPsub4()
76+
// CHECK: %[[TEMP_BOX1:.*]] = fir.alloca !fir.box<!fir.array<10xi32>>
7677
// CHECK: %[[TEMP_BOX:.*]] = fir.alloca !fir.box<!fir.array<10xi32>>
7778
// CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub4Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
7879
// CHECK: %[[AHOST:.*]]:2 = hlfir.declare %{{.*}}(%[[AHOST_SHAPE:.*]]) {uniq_name = "_QFsub4Eahost"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>)
@@ -81,13 +82,11 @@ func.func @_QPsub4() {
8182
// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
8283
// CHECK: %[[AHOST_BOX:.*]] = fir.convert %[[TEMP_BOX]] : (!fir.ref<!fir.box<!fir.array<10xi32>>>) -> !fir.ref<!fir.box<none>>
8384
// CHECK: fir.call @_FortranACUFDataTransferDescDesc(%[[ADEV_BOX]], %[[AHOST_BOX]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
84-
// CHECK: %[[NBELEM:.*]] = arith.constant 10 : index
85-
// CHECK: %[[WIDTH:.*]] = arith.constant 4 : index
86-
// CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %[[WIDTH]] : index
87-
// CHECK: %[[AHOST_PTR:.*]] = fir.convert %[[AHOST]]#0 : (!fir.ref<!fir.array<10xi32>>) -> !fir.llvm_ptr<i8>
85+
// CHECK: %[[EMBOX:.*]] = fir.embox %[[AHOST]]#0(%[[AHOST_SHAPE]]) : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> !fir.box<!fir.array<10xi32>>
86+
// CHECK: fir.store %[[EMBOX]] to %[[TEMP_BOX1]] : !fir.ref<!fir.box<!fir.array<10xi32>>>
87+
// CHECK: %[[AHOST_BOX:.*]] = fir.convert %[[TEMP_BOX1]] : (!fir.ref<!fir.box<!fir.array<10xi32>>>) -> !fir.ref<!fir.box<none>>
8888
// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
89-
// CHECK: %[[BYTES_CONV:.*]] = fir.convert %[[BYTES]] : (index) -> i64
90-
// CHECK: fir.call @_FortranACUFDataTransferPtrDesc(%[[AHOST_PTR]], %[[ADEV_BOX]], %[[BYTES_CONV]], %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<!fir.box<none>>, i64, i32, !fir.ref<i8>, i32) -> none
89+
// CHECK: fir.call @_FortranACUFDataTransferDescDescNoRealloc(%[[AHOST_BOX]], %[[ADEV_BOX]], %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
9190

9291
func.func @_QPsub5(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}) {
9392
%0 = fir.dummy_scope : !fir.dscope
@@ -115,6 +114,7 @@ func.func @_QPsub5(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}) {
115114
}
116115

117116
// CHECK-LABEL: func.func @_QPsub5
117+
// CHECK: %[[TEMP_BOX1:.*]] = fir.alloca !fir.box<!fir.array<?x?xi32>>
118118
// CHECK: %[[TEMP_BOX:.*]] = fir.alloca !fir.box<!fir.array<?x?xi32>>
119119
// CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub5Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>)
120120
// CHECK: %[[SHAPE:.*]] = fir.shape %[[I1:.*]], %[[I2:.*]] : (index, index) -> !fir.shape<2>
@@ -124,13 +124,11 @@ func.func @_QPsub5(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}) {
124124
// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>) -> !fir.ref<!fir.box<none>>
125125
// CHECK: %[[AHOST_BOX:.*]] = fir.convert %[[TEMP_BOX]] : (!fir.ref<!fir.box<!fir.array<?x?xi32>>>) -> !fir.ref<!fir.box<none>>
126126
// CHECK: fir.call @_FortranACUFDataTransferDescDesc(%[[ADEV_BOX]], %[[AHOST_BOX]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
127-
// CHECK: %[[NBELEM:.*]] = arith.muli %[[I1]], %[[I2]] : index
128-
// CHECK: %[[WIDTH:.*]] = arith.constant 4 : index
129-
// CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %[[WIDTH]] : index
130-
// CHECK: %[[AHOST_PTR:.*]] = fir.convert %[[AHOST]]#1 : (!fir.ref<!fir.array<?x?xi32>>) -> !fir.llvm_ptr<i8>
127+
// CHECK: %[[EMBOX:.*]] = fir.embox %[[AHOST]]#1(%[[SHAPE]]) : (!fir.ref<!fir.array<?x?xi32>>, !fir.shape<2>) -> !fir.box<!fir.array<?x?xi32>>
128+
// CHECK: fir.store %[[EMBOX]] to %[[TEMP_BOX1]] : !fir.ref<!fir.box<!fir.array<?x?xi32>>>
129+
// CHECK: %[[AHOST_BOX:.*]] = fir.convert %[[TEMP_BOX1]] : (!fir.ref<!fir.box<!fir.array<?x?xi32>>>) -> !fir.ref<!fir.box<none>>
131130
// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>) -> !fir.ref<!fir.box<none>>
132-
// CHECK: %[[BYTES_CONV:.*]] = fir.convert %[[BYTES]] : (index) -> i64
133-
// CHECK: fir.call @_FortranACUFDataTransferPtrDesc(%[[AHOST_PTR]], %[[ADEV_BOX]], %[[BYTES_CONV]], %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<!fir.box<none>>, i64, i32, !fir.ref<i8>, i32) -> none
131+
// CHECK: fir.call @_FortranACUFDataTransferDescDescNoRealloc(%[[AHOST_BOX]], %[[ADEV_BOX]], %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
134132

135133
func.func @_QPsub6() {
136134
%0 = cuf.alloc i32 {bindc_name = "idev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub6Eidev"} -> !fir.ref<i32>

0 commit comments

Comments
 (0)