From 02bd71cd5c9526ed50527577bbf321990e563ae9 Mon Sep 17 00:00:00 2001 From: Valentin Clement Date: Wed, 6 Nov 2024 14:45:23 -0800 Subject: [PATCH] [flang][cuda] Support scalar to array data transfer --- .../Optimizer/Transforms/CUFOpConversion.cpp | 105 +++++++++++------- flang/test/Fir/CUDA/cuda-data-transfer.fir | 14 +++ 2 files changed, 81 insertions(+), 38 deletions(-) diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp index 6187ca03d2c41..881f54133ce73 100644 --- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp +++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp @@ -448,6 +448,53 @@ static mlir::Value getShapeFromDecl(mlir::Value src) { return mlir::Value{}; } +static mlir::Value emboxSrc(mlir::PatternRewriter &rewriter, + cuf::DataTransferOp op, + const mlir::SymbolTable &symtab) { + auto mod = op->getParentOfType(); + mlir::Location loc = op.getLoc(); + fir::FirOpBuilder builder(rewriter, mod); + mlir::Value addr; + mlir::Type srcTy = fir::unwrapRefType(op.getSrc().getType()); + if (fir::isa_trivial(srcTy) && + mlir::matchPattern(op.getSrc().getDefiningOp(), mlir::m_Constant())) { + // Put constant in memory if it is not. 
+ mlir::Value alloc = builder.createTemporary(loc, srcTy); + builder.create(loc, op.getSrc(), alloc); + addr = alloc; + } else { + addr = getDeviceAddress(rewriter, op.getSrcMutable(), symtab); + } + llvm::SmallVector lenParams; + mlir::Type boxTy = fir::BoxType::get(srcTy); + mlir::Value box = + builder.createBox(loc, boxTy, addr, getShapeFromDecl(op.getSrc()), + /*slice=*/nullptr, lenParams, + /*tdesc=*/nullptr); + mlir::Value src = builder.createTemporary(loc, box.getType()); + builder.create(loc, box, src); + return src; +} + +static mlir::Value emboxDst(mlir::PatternRewriter &rewriter, + cuf::DataTransferOp op, + const mlir::SymbolTable &symtab) { + auto mod = op->getParentOfType(); + mlir::Location loc = op.getLoc(); + fir::FirOpBuilder builder(rewriter, mod); + mlir::Type dstTy = fir::unwrapRefType(op.getDst().getType()); + mlir::Value dstAddr = getDeviceAddress(rewriter, op.getDstMutable(), symtab); + mlir::Type dstBoxTy = fir::BoxType::get(dstTy); + llvm::SmallVector lenParams; + mlir::Value dstBox = + builder.createBox(loc, dstBoxTy, dstAddr, getShapeFromDecl(op.getDst()), + /*slice=*/nullptr, lenParams, + /*tdesc=*/nullptr); + mlir::Value dst = builder.createTemporary(loc, dstBox.getType()); + builder.create(loc, dstBox, dst); + return dst; +} + struct CUFDataTransferOpConversion : public mlir::OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -486,10 +533,22 @@ struct CUFDataTransferOpConversion !mlir::isa(dstTy)) { if (fir::isa_trivial(srcTy) && !fir::isa_trivial(dstTy)) { - // TODO: scalar to array data transfer. - mlir::emitError(loc, - "not yet implemented: scalar to array data transfer\n"); - return mlir::failure(); + // Initialization of an array from a scalar value should be implemented + // via a kernel launch. Use the flang runtime via the Assign function + // until we have more infrastructure. 
+ mlir::Value src = emboxSrc(rewriter, op, symtab); + mlir::Value dst = emboxDst(rewriter, op, symtab); + mlir::func::FuncOp func = fir::runtime::getRuntimeFunc(loc, builder); + auto fTy = func.getFunctionType(); + mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc); + mlir::Value sourceLine = + fir::factory::locationToLineNo(builder, loc, fTy.getInput(4)); + llvm::SmallVector args{fir::runtime::createArguments( + builder, loc, fTy, dst, src, modeValue, sourceFile, sourceLine)}; + builder.create(loc, func, args); + rewriter.eraseOp(op); + return mlir::success(); } mlir::Type i64Ty = builder.getI64Type(); @@ -548,29 +607,8 @@ struct CUFDataTransferOpConversion mlir::Value dst = op.getDst(); mlir::Value src = op.getSrc(); - if (!mlir::isa(srcTy)) { - // If src is not a descriptor, create one. - mlir::Value addr; - if (fir::isa_trivial(srcTy) && - mlir::matchPattern(op.getSrc().getDefiningOp(), - mlir::m_Constant())) { - // Put constant in memory if it is not. - mlir::Value alloc = builder.createTemporary(loc, srcTy); - builder.create(loc, op.getSrc(), alloc); - addr = alloc; - } else { - addr = getDeviceAddress(rewriter, op.getSrcMutable(), symtab); - } - mlir::Type boxTy = fir::BoxType::get(srcTy); - llvm::SmallVector lenParams; - mlir::Value box = - builder.createBox(loc, boxTy, addr, getShapeFromDecl(src), - /*slice=*/nullptr, lenParams, - /*tdesc=*/nullptr); - mlir::Value memBox = builder.createTemporary(loc, box.getType()); - builder.create(loc, box, memBox); - src = memBox; - } + if (!mlir::isa(srcTy)) + src = emboxSrc(rewriter, op, symtab); auto fTy = func.getFunctionType(); mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc); @@ -582,16 +620,7 @@ struct CUFDataTransferOpConversion rewriter.eraseOp(op); } else { // Transfer from a descriptor. 
- - mlir::Value addr = getDeviceAddress(rewriter, op.getDstMutable(), symtab); - mlir::Type boxTy = fir::BoxType::get(dstTy); - llvm::SmallVector lenParams; - mlir::Value box = - builder.createBox(loc, boxTy, addr, getShapeFromDecl(op.getDst()), - /*slice=*/nullptr, lenParams, - /*tdesc=*/nullptr); - mlir::Value memBox = builder.createTemporary(loc, box.getType()); - builder.create(loc, box, memBox); + mlir::Value dst = emboxDst(rewriter, op, symtab); mlir::func::FuncOp func = fir::runtime::getRuntimeFunc(loc, builder); @@ -601,7 +630,7 @@ struct CUFDataTransferOpConversion mlir::Value sourceLine = fir::factory::locationToLineNo(builder, loc, fTy.getInput(4)); llvm::SmallVector args{ - fir::runtime::createArguments(builder, loc, fTy, memBox, op.getSrc(), + fir::runtime::createArguments(builder, loc, fTy, dst, op.getSrc(), modeValue, sourceFile, sourceLine)}; builder.create(loc, func, args); rewriter.eraseOp(op); diff --git a/flang/test/Fir/CUDA/cuda-data-transfer.fir b/flang/test/Fir/CUDA/cuda-data-transfer.fir index d9588942b21e8..8497aee2e2cf9 100644 --- a/flang/test/Fir/CUDA/cuda-data-transfer.fir +++ b/flang/test/Fir/CUDA/cuda-data-transfer.fir @@ -281,4 +281,18 @@ func.func @_QPdesc_global_ptr() { // CHECK: %[[AHOST_BOXNONE:.*]] = fir.convert %[[TEMP_BOX]] : (!fir.ref>>) -> !fir.ref> // CHECK: fir.call @_FortranACUFDataTransferGlobalDescDesc(%[[ADEV_BOXNONE]], %[[AHOST_BOXNONE]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref>, !fir.ref>, i32, !fir.ref, i32) -> none +func.func @_QPscalar_to_array() { + %c1_i32 = arith.constant 1 : i32 + %c10 = arith.constant 10 : index + %0 = cuf.alloc !fir.array<10xi32> {bindc_name = "a", data_attr = #cuf.cuda, uniq_name = "_QFscalar_to_arrayEa"} -> !fir.ref> + %1 = fir.shape %c10 : (index) -> !fir.shape<1> + %2:2 = hlfir.declare %0(%1) {data_attr = #cuf.cuda, uniq_name = "_QFscalar_to_arrayEa"} : (!fir.ref>, !fir.shape<1>) -> (!fir.ref>, !fir.ref>) + cuf.data_transfer %c1_i32 to %2#0 {transfer_kind = #cuf.cuda_transfer} : i32, 
!fir.ref> + cuf.free %2#1 : !fir.ref> {data_attr = #cuf.cuda} + return +} + +// CHECK-LABEL: func.func @_QPscalar_to_array() +// CHECK: _FortranACUFDataTransferDescDescNoRealloc + } // end of module