diff --git a/flang/lib/Optimizer/Transforms/CufOpConversion.cpp b/flang/lib/Optimizer/Transforms/CufOpConversion.cpp index 9c566c10daff5..a1405d0e85c1d 100644 --- a/flang/lib/Optimizer/Transforms/CufOpConversion.cpp +++ b/flang/lib/Optimizer/Transforms/CufOpConversion.cpp @@ -370,11 +370,6 @@ struct CufDataTransferOpConversion mlir::Type srcTy = fir::unwrapRefType(op.getSrc().getType()); mlir::Type dstTy = fir::unwrapRefType(op.getDst().getType()); - // Only convert cuf.data_transfer with at least one descripor. - if (!mlir::isa(srcTy) && - !mlir::isa(dstTy)) - return failure(); - unsigned mode; if (op.getTransferKind() == cuf::DataTransferKind::HostDevice) { mode = kHostToDevice; @@ -387,7 +382,64 @@ struct CufDataTransferOpConversion auto mod = op->getParentOfType(); fir::FirOpBuilder builder(rewriter, mod); mlir::Location loc = op.getLoc(); + fir::KindMapping kindMap{fir::getKindMapping(mod)}; + mlir::Value modeValue = + builder.createIntegerConstant(loc, builder.getI32Type(), mode); + + // Convert data transfer without any descriptor. + if (!mlir::isa(srcTy) && + !mlir::isa(dstTy)) { + + if (fir::isa_trivial(srcTy) && !fir::isa_trivial(dstTy)) { + // TODO: scalar to array data transfer. + mlir::emitError(loc, + "not yet implemented: scalar to array data transfer\n"); + return mlir::failure(); + } + + mlir::Type i64Ty = builder.getI64Type(); + mlir::Value nbElement; + if (op.getShape()) { + auto shapeOp = + mlir::dyn_cast(op.getShape().getDefiningOp()); + nbElement = rewriter.create(loc, i64Ty, + shapeOp.getExtents()[0]); + for (unsigned i = 1; i < shapeOp.getExtents().size(); ++i) { + auto operand = rewriter.create( + loc, i64Ty, shapeOp.getExtents()[i]); + nbElement = + rewriter.create(loc, nbElement, operand); + } + } else { + if (auto seqTy = mlir::dyn_cast_or_null(dstTy)) + nbElement = builder.createIntegerConstant( + loc, i64Ty, seqTy.getConstantArraySize()); + } + int width = computeWidth(loc, dstTy, kindMap); + mlir::Value widthValue = rewriter.create( + loc, i64Ty, rewriter.getIntegerAttr(i64Ty, width)); + mlir::Value bytes = + nbElement + ? rewriter.create(loc, nbElement, widthValue) + : widthValue; + + mlir::func::FuncOp func = + fir::runtime::getRuntimeFunc(loc, + builder); + auto fTy = func.getFunctionType(); + mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc); + mlir::Value sourceLine = + fir::factory::locationToLineNo(builder, loc, fTy.getInput(5)); + + llvm::SmallVector args{fir::runtime::createArguments( + builder, loc, fTy, op.getDst(), op.getSrc(), bytes, modeValue, + sourceFile, sourceLine)}; + builder.create(loc, func, args); + rewriter.eraseOp(op); + return mlir::success(); + } + // Conversion of data transfer involving at least one descriptor. if (mlir::isa(srcTy) && mlir::isa(dstTy)) { // Transfer between two descriptor. @@ -396,8 +448,6 @@ struct CufDataTransferOpConversion loc, builder); auto fTy = func.getFunctionType(); - mlir::Value modeValue = - builder.createIntegerConstant(loc, builder.getI32Type(), mode); mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc); mlir::Value sourceLine = fir::factory::locationToLineNo(builder, loc, fTy.getInput(4)); @@ -430,8 +480,6 @@ struct CufDataTransferOpConversion builder.create(loc, func, args); rewriter.eraseOp(op); } else { - mlir::Value modeValue = - builder.createIntegerConstant(loc, builder.getI32Type(), mode); // Type used to compute the width. mlir::Type computeType = dstTy; auto seqTy = mlir::dyn_cast(dstTy); @@ -441,7 +489,6 @@ struct CufDataTransferOpConversion computeType = srcTy; seqTy = mlir::dyn_cast(srcTy); } - fir::KindMapping kindMap{fir::getKindMapping(mod)}; int width = computeWidth(loc, computeType, kindMap); mlir::Value nbElement; @@ -509,13 +556,6 @@ class CufOpConversion : public fir::impl::CufOpConversionBase { fir::support::getOrSetDataLayout(module, /*allowDefaultLayout=*/false); fir::LLVMTypeConverter typeConverter(module, /*applyTBAA=*/false, /*forceUnifiedTBAATree=*/false, *dl); - target.addDynamicallyLegalOp( - [](::cuf::DataTransferOp op) { - mlir::Type srcTy = fir::unwrapRefType(op.getSrc().getType()); - mlir::Type dstTy = fir::unwrapRefType(op.getDst().getType()); - return !mlir::isa(srcTy) && - !mlir::isa(dstTy); - }); target.addLegalDialect(); cuf::populateCUFToFIRConversionPatterns(typeConverter, *dl, patterns); if (mlir::failed(mlir::applyPartialConversion(getOperation(), target, diff --git a/flang/test/Fir/CUDA/cuda-data-transfer.fir b/flang/test/Fir/CUDA/cuda-data-transfer.fir index f639a6c22b76d..ed894aed5534a 100644 --- a/flang/test/Fir/CUDA/cuda-data-transfer.fir +++ b/flang/test/Fir/CUDA/cuda-data-transfer.fir @@ -70,7 +70,6 @@ func.func @_QPsub4() { cuf.free %4#1 : !fir.ref>>> {data_attr = #cuf.cuda} return } - // CHECK-LABEL: func.func @_QPsub4() // CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda, fortran_attrs = #fir.var_attrs, uniq_name = "_QFsub4Eadev"} : (!fir.ref>>>) -> (!fir.ref>>>, !fir.ref>>>) // CHECK: %[[AHOST:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}}) {uniq_name = "_QFsub4Eahost"} : (!fir.ref>, !fir.shape<1>) -> (!fir.ref>, !fir.ref>) @@ -137,4 +136,57 @@ func.func @_QPsub5(%arg0: !fir.ref {fir.bindc_name = "n"}) { // CHECK: %[[BYTES_CONV:.*]] = fir.convert %[[BYTES]] : (index) -> i64 // CHECK: fir.call @_FortranACUFDataTransferPtrDesc(%[[AHOST_PTR]], %[[ADEV_BOX]], %[[BYTES_CONV]], %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr, !fir.box, i64, i32, !fir.ref, i32) -> none +func.func @_QPsub6() { + %0 = cuf.alloc i32 {bindc_name = "idev", data_attr = #cuf.cuda, uniq_name = "_QFsub6Eidev"} -> !fir.ref + %1:2 = hlfir.declare %0 {data_attr = #cuf.cuda, uniq_name = "_QFsub6Eidev"} : (!fir.ref) -> (!fir.ref, !fir.ref) + %2 = fir.alloca i32 {bindc_name = "ihost", uniq_name = "_QFsub6Eihost"} + %3:2 = hlfir.declare %2 {uniq_name = "_QFsub6Eihost"} : (!fir.ref) -> (!fir.ref, !fir.ref) + cuf.data_transfer %1#0 to %3#0 {transfer_kind = #cuf.cuda_transfer} : !fir.ref, !fir.ref + %4 = fir.load %3#0 : !fir.ref + %5:3 = hlfir.associate %4 {uniq_name = ".cuf_host_tmp"} : (i32) -> (!fir.ref, !fir.ref, i1) + cuf.data_transfer %5#0 to %1#0 {transfer_kind = #cuf.cuda_transfer} : !fir.ref, !fir.ref + hlfir.end_associate %5#1, %5#2 : !fir.ref, i1 + cuf.free %1#1 : !fir.ref {data_attr = #cuf.cuda} + return +} + +// CHECK-LABEL: func.func @_QPsub6() +// CHECK: %[[IDEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda, uniq_name = "_QFsub6Eidev"} : (!fir.ref) -> (!fir.ref, !fir.ref) +// CHECK: %[[IHOST:.*]]:2 = hlfir.declare %{{.*}} {uniq_name = "_QFsub6Eihost"} : (!fir.ref) -> (!fir.ref, !fir.ref) +// CHECK: %[[DST:.*]] = fir.convert %[[IHOST]]#0 : (!fir.ref) -> !fir.llvm_ptr +// CHECK: %[[SRC:.*]] = fir.convert %[[IDEV]]#0 : (!fir.ref) -> !fir.llvm_ptr +// CHECK: fir.call @_FortranACUFDataTransferPtrPtr(%[[DST]], %[[SRC]], %c4{{.*}}, %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr, !fir.llvm_ptr, i64, i32, !fir.ref, i32) -> none +// CHECK: %[[LOAD:.*]] = fir.load %[[IHOST]]#0 : !fir.ref +// CHECK: %[[ASSOC:.*]]:3 = hlfir.associate %[[LOAD]] {uniq_name = ".cuf_host_tmp"} : (i32) -> (!fir.ref, !fir.ref, i1) +// CHECK: %[[DST:.*]] = fir.convert %[[IDEV]]#0 : (!fir.ref) -> !fir.llvm_ptr +// CHECK: %[[SRC:.*]] = fir.convert %[[ASSOC]]#0 : (!fir.ref) -> !fir.llvm_ptr +// CHECK: fir.call @_FortranACUFDataTransferPtrPtr(%[[DST]], %[[SRC]], %c4{{.*}}, %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr, !fir.llvm_ptr, i64, i32, !fir.ref, i32) -> none + +func.func @_QPsub7() { + %c10 = arith.constant 10 : index + %0 = cuf.alloc !fir.array<10xi32> {bindc_name = "idev", data_attr = #cuf.cuda, uniq_name = "_QFsub7Eidev"} -> !fir.ref> + %1 = fir.shape %c10 : (index) -> !fir.shape<1> + %2:2 = hlfir.declare %0(%1) {data_attr = #cuf.cuda, uniq_name = "_QFsub7Eidev"} : (!fir.ref>, !fir.shape<1>) -> (!fir.ref>, !fir.ref>) + %c10_0 = arith.constant 10 : index + %3 = fir.alloca !fir.array<10xi32> {bindc_name = "ihost", uniq_name = "_QFsub7Eihost"} + %4 = fir.shape %c10_0 : (index) -> !fir.shape<1> + %5:2 = hlfir.declare %3(%4) {uniq_name = "_QFsub7Eihost"} : (!fir.ref>, !fir.shape<1>) -> (!fir.ref>, !fir.ref>) + cuf.data_transfer %2#0 to %5#0 {transfer_kind = #cuf.cuda_transfer} : !fir.ref>, !fir.ref> + cuf.data_transfer %5#0 to %2#0 {transfer_kind = #cuf.cuda_transfer} : !fir.ref>, !fir.ref> + cuf.free %2#1 : !fir.ref> {data_attr = #cuf.cuda} + return +} + +// CHECK-LABEL: func.func @_QPsub7() +// CHECK: %[[IDEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda, uniq_name = "_QFsub7Eidev"} : (!fir.ref>, !fir.shape<1>) -> (!fir.ref>, !fir.ref>) +// CHECK: %[[IHOST:.*]]:2 = hlfir.declare %{{.*}} {uniq_name = "_QFsub7Eihost"} : (!fir.ref>, !fir.shape<1>) -> (!fir.ref>, !fir.ref>) +// CHECK: %[[BYTES:.*]] = arith.muli %c10{{.*}}, %c4{{.*}} : i64 +// CHECK: %[[DST:.*]] = fir.convert %[[IHOST]]#0 : (!fir.ref>) -> !fir.llvm_ptr +// CHECK: %[[SRC:.*]] = fir.convert %[[IDEV]]#0 : (!fir.ref>) -> !fir.llvm_ptr +// CHECK: fir.call @_FortranACUFDataTransferPtrPtr(%[[DST]], %[[SRC]], %[[BYTES]], %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr, !fir.llvm_ptr, i64, i32, !fir.ref, i32) -> none +// CHECK: %[[BYTES:.*]] = arith.muli %c10{{.*}}, %c4{{.*}} : i64 +// CHECK: %[[DST:.*]] = fir.convert %[[IDEV]]#0 : (!fir.ref>) -> !fir.llvm_ptr +// CHECK: %[[SRC:.*]] = fir.convert %[[IHOST]]#0 : (!fir.ref>) -> !fir.llvm_ptr +// CHECK: fir.call @_FortranACUFDataTransferPtrPtr(%[[DST]], %[[SRC]], %[[BYTES]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr, !fir.llvm_ptr, i64, i32, !fir.ref, i32) -> none + } // end of module