diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp index b14a1c338e087..18864c72d684f 100644 --- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp +++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp @@ -268,24 +268,23 @@ static bool inDeviceContext(mlir::Operation *op) { static int computeWidth(mlir::Location loc, mlir::Type type, fir::KindMapping &kindMap) { auto eleTy = fir::unwrapSequenceType(type); - int width = 0; - if (auto t{mlir::dyn_cast(eleTy)}) { - width = t.getWidth() / 8; - } else if (auto t{mlir::dyn_cast(eleTy)}) { - width = t.getWidth() / 8; - } else if (eleTy.isInteger(1)) { - width = 1; - } else if (auto t{mlir::dyn_cast(eleTy)}) { - int kind = t.getFKind(); - width = kindMap.getLogicalBitsize(kind) / 8; - } else if (auto t{mlir::dyn_cast(eleTy)}) { + if (auto t{mlir::dyn_cast(eleTy)}) + return t.getWidth() / 8; + if (auto t{mlir::dyn_cast(eleTy)}) + return t.getWidth() / 8; + if (eleTy.isInteger(1)) + return 1; + if (auto t{mlir::dyn_cast(eleTy)}) + return kindMap.getLogicalBitsize(t.getFKind()) / 8; + if (auto t{mlir::dyn_cast(eleTy)}) { int elemSize = mlir::cast(t.getElementType()).getWidth() / 8; - width = 2 * elemSize; - } else { - mlir::emitError(loc, "unsupported type"); + return 2 * elemSize; } - return width; + if (auto t{mlir::dyn_cast_or_null(eleTy)}) + return kindMap.getCharacterBitsize(t.getFKind()) / 8; + mlir::emitError(loc, "unsupported type"); + return 0; } struct CUFAllocOpConversion : public mlir::OpRewritePattern { diff --git a/flang/test/Fir/CUDA/cuda-alloc-free.fir b/flang/test/Fir/CUDA/cuda-alloc-free.fir index 25545d1f72f52..49bb5bdf5e6bc 100644 --- a/flang/test/Fir/CUDA/cuda-alloc-free.fir +++ b/flang/test/Fir/CUDA/cuda-alloc-free.fir @@ -83,4 +83,15 @@ gpu.module @cuda_device_mod [#nvvm.target] { // CHECK-LABEL: gpu.func @_QMalloc() kernel // CHECK: fir.alloca !fir.box>> {bindc_name = "a", uniq_name = "_QMallocEa"} +func.func @_QQalloc_char() attributes {fir.bindc_name = "alloc_char"} { + %c1 = arith.constant 1 : index + %0 = cuf.alloc !fir.array<10x!fir.char<1>>(%c1 : index) {bindc_name = "a", data_attr = #cuf.cuda, uniq_name = "_QFEa"} -> !fir.ref>> + return +} + +// CHECK-LABEL: func.func @_QQalloc_char() +// CHECK: %[[BYTES:.*]] = arith.muli %c10{{.*}}, %c1{{.*}} : index +// CHECK: %[[BYTES_CONV:.*]] = fir.convert %[[BYTES]] : (index) -> i64 +// CHECK: fir.call @_FortranACUFMemAlloc(%[[BYTES_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (i64, i32, !fir.ref, i32) -> !fir.llvm_ptr + } // end module diff --git a/flang/test/Fir/CUDA/cuda-data-transfer.fir b/flang/test/Fir/CUDA/cuda-data-transfer.fir index 69baf7d15a7d0..b0eb821a54548 100644 --- a/flang/test/Fir/CUDA/cuda-data-transfer.fir +++ b/flang/test/Fir/CUDA/cuda-data-transfer.fir @@ -385,6 +385,25 @@ func.func @_QPdevice_addr_conv() { // CHECK: fir.embox %[[DEV_ADDR_CONV]](%{{.*}}) : (!fir.ref>, !fir.shape<1>) -> !fir.box> // CHECK: fir.call @_FortranACUFDataTransferDescDescNoRealloc + +func.func @_QQchar_transfer() attributes {fir.bindc_name = "char_transfer"} { + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + %0 = cuf.alloc !fir.array<10x!fir.char<1>>(%c1 : index) {bindc_name = "a", data_attr = #cuf.cuda, uniq_name = "_QFEa"} -> !fir.ref>> + %1 = fir.shape %c10 : (index) -> !fir.shape<1> + %2 = fir.declare %0(%1) typeparams %c1 {data_attr = #cuf.cuda, uniq_name = "_QFEa"} : (!fir.ref>>, !fir.shape<1>, index) -> !fir.ref>> + %3 = fir.alloca !fir.array<10x!fir.char<1>> {bindc_name = "b", uniq_name = "_QFEb"} + %4 = fir.declare %3(%1) typeparams %c1 {uniq_name = "_QFEb"} : (!fir.ref>>, !fir.shape<1>, index) -> !fir.ref>> + cuf.data_transfer %4 to %2 {transfer_kind = #cuf.cuda_transfer} : !fir.ref>>, !fir.ref>> + cuf.free %2 : !fir.ref>> {data_attr = #cuf.cuda} + return +} + +// CHECK-LABEL: func.func @_QQchar_transfer() +// CHECK: fir.call @_FortranACUFMemAlloc +// CHECK: %[[BYTES:.*]] = arith.muli %c10{{.*}}, %c1{{.*}} : i64 +// CHECK: fir.call @_FortranACUFDataTransferPtrPtr(%{{.*}}, %{{.*}}, %[[BYTES]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr, !fir.llvm_ptr, i64, i32, !fir.ref, i32) -> none + func.func @_QPdevmul(%arg0: !fir.ref> {fir.bindc_name = "b"}, %arg1: !fir.ref {fir.bindc_name = "wa"}, %arg2: !fir.ref {fir.bindc_name = "wb"}) { %c0_i64 = arith.constant 0 : i64 %c1_i32 = arith.constant 1 : i32 @@ -424,4 +443,5 @@ func.func @_QPdevmul(%arg0: !fir.ref> {fir.bindc_name = "b"} // CHECK: %[[SRC:.*]] = fir.convert %[[ALLOCA]] : (!fir.ref>>) -> !fir.ref> // CHECK: fir.call @_FortranACUFDataTransferDescDesc(%{{.*}}, %[[SRC]], %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref>, !fir.ref>, i32, !fir.ref, i32) -> none + } // end of module