Skip to content

Commit 20c0304

Browse files
committed
[flang][cuda] Avoid assign element mismatch when doing data transfer from a constant
1 parent 7e4fef6 commit 20c0304

File tree

2 files changed

+41
-7
lines changed

2 files changed

+41
-7
lines changed

flang/lib/Optimizer/Transforms/CUFOpConversion.cpp

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -541,7 +541,8 @@ static mlir::Value getShapeFromDecl(mlir::Value src) {
541541

542542
static mlir::Value emboxSrc(mlir::PatternRewriter &rewriter,
543543
cuf::DataTransferOp op,
544-
const mlir::SymbolTable &symtab) {
544+
const mlir::SymbolTable &symtab,
545+
mlir::Type dstEleTy = nullptr) {
545546
auto mod = op->getParentOfType<mlir::ModuleOp>();
546547
mlir::Location loc = op.getLoc();
547548
fir::FirOpBuilder builder(rewriter, mod);
@@ -555,11 +556,21 @@ static mlir::Value emboxSrc(mlir::PatternRewriter &rewriter,
555556
// from a LOGICAL constant. Store it as a fir.logical.
556557
srcTy = fir::LogicalType::get(rewriter.getContext(), 4);
557558
src = createConvertOp(rewriter, loc, srcTy, src);
559+
addr = builder.createTemporary(loc, srcTy);
560+
builder.create<fir::StoreOp>(loc, src, addr);
561+
} else {
562+
if (dstEleTy && fir::isa_trivial(dstEleTy) && srcTy != dstEleTy) {
563+
// Use dstEleTy and convert to avoid assign mismatch.
564+
addr = builder.createTemporary(loc, dstEleTy);
565+
auto conv = builder.create<fir::ConvertOp>(loc, dstEleTy, src);
566+
builder.create<fir::StoreOp>(loc, conv, addr);
567+
srcTy = dstEleTy;
568+
} else {
569+
// Put constant in memory if it is not.
570+
addr = builder.createTemporary(loc, srcTy);
571+
builder.create<fir::StoreOp>(loc, src, addr);
572+
}
558573
}
559-
// Put constant in memory if it is not.
560-
mlir::Value alloc = builder.createTemporary(loc, srcTy);
561-
builder.create<fir::StoreOp>(loc, src, alloc);
562-
addr = alloc;
563574
} else {
564575
addr = op.getSrc();
565576
}
@@ -729,7 +740,7 @@ struct CUFDataTransferOpConversion
729740
};
730741

731742
// Conversion of data transfer involving at least one descriptor.
732-
if (mlir::isa<fir::BaseBoxType>(dstTy)) {
743+
if (auto dstBoxTy = mlir::dyn_cast<fir::BaseBoxType>(dstTy)) {
733744
// Transfer to a descriptor.
734745
mlir::func::FuncOp func =
735746
isDstGlobal(op)
@@ -740,7 +751,8 @@ struct CUFDataTransferOpConversion
740751
mlir::Value dst = op.getDst();
741752
mlir::Value src = op.getSrc();
742753
if (!mlir::isa<fir::BaseBoxType>(srcTy)) {
743-
src = emboxSrc(rewriter, op, symtab);
754+
mlir::Type dstEleTy = fir::unwrapInnerType(dstBoxTy.getEleTy());
755+
src = emboxSrc(rewriter, op, symtab, dstEleTy);
744756
if (fir::isa_trivial(srcTy))
745757
func = fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferCstDesc)>(
746758
loc, builder);

flang/test/Fir/CUDA/cuda-data-transfer.fir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -582,4 +582,26 @@ func.func @_QPchecksums(%arg0: !fir.box<!fir.array<?xf64>> {cuf.data_attr = #cuf
582582
// CHECK: %[[SRC:.*]] = fir.convert %{{.*}} : (!fir.ref<!fir.box<!fir.array<?xf64>>>) -> !fir.ref<!fir.box<none>>
583583
// CHECK: fir.call @_FortranACUFDataTransferDescDescNoRealloc(%[[DST]], %[[SRC]], %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> ()
584584

585+
func.func @_QPsub20() {
586+
%0 = cuf.alloc !fir.box<!fir.heap<f32>> {bindc_name = "r", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub20Er"} -> !fir.ref<!fir.box<!fir.heap<f32>>>
587+
%1 = fir.zero_bits !fir.heap<f32>
588+
%2 = fir.embox %1 {allocator_idx = 2 : i32} : (!fir.heap<f32>) -> !fir.box<!fir.heap<f32>>
589+
fir.store %2 to %0 : !fir.ref<!fir.box<!fir.heap<f32>>>
590+
%3:2 = hlfir.declare %0 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub20Er"} : (!fir.ref<!fir.box<!fir.heap<f32>>>) -> (!fir.ref<!fir.box<!fir.heap<f32>>>, !fir.ref<!fir.box<!fir.heap<f32>>>)
591+
%c0_i32 = arith.constant 0 : i32
592+
cuf.data_transfer %c0_i32 to %3#0 {transfer_kind = #cuf.cuda_transfer<host_device>} : i32, !fir.ref<!fir.box<!fir.heap<f32>>>
593+
return
594+
}
595+
596+
// CHECK-LABEL:func.func @_QPsub20
597+
// CHECK: %[[BOX_ALLOCA:.*]] = fir.alloca !fir.box<f32>
598+
// CHECK: %[[TMP:.*]] = fir.alloca f32
599+
// CHECK: %[[CONV:.*]] = fir.convert %c0{{.*}} : (i32) -> f32
600+
// CHECK: fir.store %[[CONV]] to %[[TMP]] : !fir.ref<f32>
601+
// CHECK: %[[BOX:.*]] = fir.embox %[[TMP]] : (!fir.ref<f32>) -> !fir.box<f32>
602+
// CHECK: fir.store %[[BOX]] to %[[BOX_ALLOCA]] : !fir.ref<!fir.box<f32>>
603+
// CHECK: %[[BOX_NONE:.*]] = fir.convert %[[BOX_ALLOCA]] : (!fir.ref<!fir.box<f32>>) -> !fir.ref<!fir.box<none>>
604+
// CHECK: fir.call @_FortranACUFDataTransferCstDesc(%13, %[[BOX_NONE]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> ()
605+
585606
} // end of module
607+

0 commit comments

Comments
 (0)