diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp index 1f0576aa82f83..2ab2d84f1643d 100644 --- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp +++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp @@ -541,7 +541,8 @@ static mlir::Value getShapeFromDecl(mlir::Value src) { static mlir::Value emboxSrc(mlir::PatternRewriter &rewriter, cuf::DataTransferOp op, - const mlir::SymbolTable &symtab) { + const mlir::SymbolTable &symtab, + mlir::Type dstEleTy = nullptr) { auto mod = op->getParentOfType(); mlir::Location loc = op.getLoc(); fir::FirOpBuilder builder(rewriter, mod); @@ -555,11 +556,21 @@ static mlir::Value emboxSrc(mlir::PatternRewriter &rewriter, // from a LOGICAL constant. Store it as a fir.logical. srcTy = fir::LogicalType::get(rewriter.getContext(), 4); src = createConvertOp(rewriter, loc, srcTy, src); + addr = builder.createTemporary(loc, srcTy); + builder.create(loc, src, addr); + } else { + if (dstEleTy && fir::isa_trivial(dstEleTy) && srcTy != dstEleTy) { + // Use dstEleTy and convert to avoid assign mismatch. + addr = builder.createTemporary(loc, dstEleTy); + auto conv = builder.create(loc, dstEleTy, src); + builder.create(loc, conv, addr); + srcTy = dstEleTy; + } else { + // Put constant in memory if it is not. + addr = builder.createTemporary(loc, srcTy); + builder.create(loc, src, addr); + } } - // Put constant in memory if it is not. - mlir::Value alloc = builder.createTemporary(loc, srcTy); - builder.create(loc, src, alloc); - addr = alloc; } else { addr = op.getSrc(); } @@ -729,7 +740,7 @@ struct CUFDataTransferOpConversion }; // Conversion of data transfer involving at least one descriptor. - if (mlir::isa(dstTy)) { + if (auto dstBoxTy = mlir::dyn_cast(dstTy)) { // Transfer to a descriptor. 
mlir::func::FuncOp func = isDstGlobal(op) @@ -740,7 +751,8 @@ struct CUFDataTransferOpConversion mlir::Value dst = op.getDst(); mlir::Value src = op.getSrc(); if (!mlir::isa(srcTy)) { - src = emboxSrc(rewriter, op, symtab); + mlir::Type dstEleTy = fir::unwrapInnerType(dstBoxTy.getEleTy()); + src = emboxSrc(rewriter, op, symtab, dstEleTy); if (fir::isa_trivial(srcTy)) func = fir::runtime::getRuntimeFunc( loc, builder); diff --git a/flang/test/Fir/CUDA/cuda-data-transfer.fir b/flang/test/Fir/CUDA/cuda-data-transfer.fir index b62c500f4a2d3..a724d9f681fb6 100644 --- a/flang/test/Fir/CUDA/cuda-data-transfer.fir +++ b/flang/test/Fir/CUDA/cuda-data-transfer.fir @@ -582,4 +582,26 @@ func.func @_QPchecksums(%arg0: !fir.box> {cuf.data_attr = #cuf // CHECK: %[[SRC:.*]] = fir.convert %{{.*}} : (!fir.ref>>) -> !fir.ref> // CHECK: fir.call @_FortranACUFDataTransferDescDescNoRealloc(%[[DST]], %[[SRC]], %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref>, !fir.ref>, i32, !fir.ref, i32) -> () +func.func @_QPsub20() { + %0 = cuf.alloc !fir.box> {bindc_name = "r", data_attr = #cuf.cuda, uniq_name = "_QFsub20Er"} -> !fir.ref>> + %1 = fir.zero_bits !fir.heap + %2 = fir.embox %1 {allocator_idx = 2 : i32} : (!fir.heap) -> !fir.box> + fir.store %2 to %0 : !fir.ref>> + %3:2 = hlfir.declare %0 {data_attr = #cuf.cuda, fortran_attrs = #fir.var_attrs, uniq_name = "_QFsub20Er"} : (!fir.ref>>) -> (!fir.ref>>, !fir.ref>>) + %c0_i32 = arith.constant 0 : i32 + cuf.data_transfer %c0_i32 to %3#0 {transfer_kind = #cuf.cuda_transfer} : i32, !fir.ref>> + return +} + +// CHECK-LABEL:func.func @_QPsub20 +// CHECK: %[[BOX_ALLOCA:.*]] = fir.alloca !fir.box +// CHECK: %[[TMP:.*]] = fir.alloca f32 +// CHECK: %[[CONV:.*]] = fir.convert %c0{{.*}} : (i32) -> f32 +// CHECK: fir.store %[[CONV]] to %[[TMP]] : !fir.ref +// CHECK: %[[BOX:.*]] = fir.embox %[[TMP]] : (!fir.ref) -> !fir.box +// CHECK: fir.store %[[BOX]] to %[[BOX_ALLOCA]] : !fir.ref> +// CHECK: %[[BOX_NONE:.*]] = fir.convert %[[BOX_ALLOCA]] : (!fir.ref>) -> 
!fir.ref> +// CHECK: fir.call @_FortranACUFDataTransferCstDesc(%{{.*}}, %[[BOX_NONE]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref>, !fir.ref>, i32, !fir.ref, i32) -> () + } // end of module +