Skip to content

Commit 701b839

Browse files
authored
[flang][cuda] Fix type mismatch when transferring logical (llvm#157952)
1 parent 8ae3aea commit 701b839

File tree

2 files changed

+58
-10
lines changed

2 files changed

+58
-10
lines changed

flang/lib/Optimizer/Transforms/CUFOpConversion.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -557,8 +557,8 @@ static mlir::Value emboxSrc(mlir::PatternRewriter &rewriter,
557557
mlir::Value src = op.getSrc();
558558
if (srcTy.isInteger(1)) {
559559
// i1 is not a supported type in the descriptor and it is actually coming
560-
// from a LOGICAL constant. Store it as a fir.logical.
561-
srcTy = fir::LogicalType::get(rewriter.getContext(), 4);
560+
// from a LOGICAL constant. Use the destination type to avoid mismatch.
561+
srcTy = dstEleTy;
562562
src = createConvertOp(rewriter, loc, srcTy, src);
563563
addr = builder.createTemporary(loc, srcTy);
564564
fir::StoreOp::create(builder, loc, src, addr);
@@ -650,7 +650,7 @@ struct CUFDataTransferOpConversion
650650

651651
if (fir::isa_trivial(srcTy) && !fir::isa_trivial(dstTy)) {
652652
// Initialization of an array from a scalar value should be implemented
653-
// via a kernel launch. Use the flan runtime via the Assign function
653+
// via a kernel launch. Use the flang runtime via the Assign function
654654
// until we have more infrastructure.
655655
mlir::Value src = emboxSrc(rewriter, op, symtab);
656656
mlir::Value dst = emboxDst(rewriter, op, symtab);

flang/test/Fir/CUDA/cuda-data-transfer.fir

Lines changed: 55 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -463,13 +463,13 @@ func.func @_QPlogical_cst() {
463463
}
464464

465465
// CHECK-LABEL: func.func @_QPlogical_cst()
466-
// CHECK: %[[DESC:.*]] = fir.alloca !fir.box<!fir.logical<4>>
467-
// CHECK: %[[CONST:.*]] = fir.alloca !fir.logical<4>
468-
// CHECK: %[[CONV:.*]] = fir.convert %false : (i1) -> !fir.logical<4>
469-
// CHECK: fir.store %[[CONV]] to %[[CONST]] : !fir.ref<!fir.logical<4>>
470-
// CHECK: %[[EMBOX:.*]] = fir.embox %[[CONST]] : (!fir.ref<!fir.logical<4>>) -> !fir.box<!fir.logical<4>>
471-
// CHECK: fir.store %[[EMBOX]] to %[[DESC]] : !fir.ref<!fir.box<!fir.logical<4>>>
472-
// CHECK: %[[BOX_NONE:.*]] = fir.convert %[[DESC]] : (!fir.ref<!fir.box<!fir.logical<4>>>) -> !fir.ref<!fir.box<none>>
466+
// CHECK: %[[DESC:.*]] = fir.alloca !fir.box<!fir.logical<1>>
467+
// CHECK: %[[CONST:.*]] = fir.alloca !fir.logical<1>
468+
// CHECK: %[[CONV:.*]] = fir.convert %false : (i1) -> !fir.logical<1>
469+
// CHECK: fir.store %[[CONV]] to %[[CONST]] : !fir.ref<!fir.logical<1>>
470+
// CHECK: %[[EMBOX:.*]] = fir.embox %[[CONST]] : (!fir.ref<!fir.logical<1>>) -> !fir.box<!fir.logical<1>>
471+
// CHECK: fir.store %[[EMBOX]] to %[[DESC]] : !fir.ref<!fir.box<!fir.logical<1>>>
472+
// CHECK: %[[BOX_NONE:.*]] = fir.convert %[[DESC]] : (!fir.ref<!fir.box<!fir.logical<1>>>) -> !fir.ref<!fir.box<none>>
473473
// CHECK: fir.call @_FortranACUFDataTransferCstDesc(%{{.*}}, %[[BOX_NONE]], %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> ()
474474

475475
func.func @_QPcallkernel(%arg0: !fir.box<!fir.array<?x?xcomplex<f32>>> {fir.bindc_name = "a"}, %arg1: !fir.ref<f32> {fir.bindc_name = "b"}, %arg2: !fir.ref<f32> {fir.bindc_name = "c"}) {
@@ -603,5 +603,53 @@ func.func @_QPsub20() {
603603
// CHECK: %[[BOX_NONE:.*]] = fir.convert %[[BOX_ALLOCA]] : (!fir.ref<!fir.box<f32>>) -> !fir.ref<!fir.box<none>>
604604
// CHECK: fir.call @_FortranACUFDataTransferCstDesc(%13, %[[BOX_NONE]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> ()
605605

606+
func.func @_QPsub28() {
607+
%0 = fir.dummy_scope : !fir.dscope
608+
%1 = cuf.alloc !fir.box<!fir.heap<!fir.array<?x?x!fir.logical<8>>>> {bindc_name = "id2", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub28Eid2"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?x?x!fir.logical<8>>>>>
609+
%2 = fir.zero_bits !fir.heap<!fir.array<?x?x!fir.logical<8>>>
610+
%c0 = arith.constant 0 : index
611+
%3 = fir.shape %c0, %c0 : (index, index) -> !fir.shape<2>
612+
%4 = fir.embox %2(%3) {allocator_idx = 2 : i32} : (!fir.heap<!fir.array<?x?x!fir.logical<8>>>, !fir.shape<2>) -> !fir.box<!fir.heap<!fir.array<?x?x!fir.logical<8>>>>
613+
fir.store %4 to %1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?x!fir.logical<8>>>>>
614+
%5:2 = hlfir.declare %1 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub28Eid2"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?x!fir.logical<8>>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?x!fir.logical<8>>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?x?x!fir.logical<8>>>>>)
615+
%c1 = arith.constant 1 : index
616+
%c10_i32 = arith.constant 10 : i32
617+
%c0_i32 = arith.constant 0 : i32
618+
%6 = fir.convert %5#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?x!fir.logical<8>>>>>) -> !fir.ref<!fir.box<none>>
619+
%7 = fir.convert %c1 : (index) -> i64
620+
%8 = fir.convert %c10_i32 : (i32) -> i64
621+
fir.call @_FortranAAllocatableSetBounds(%6, %c0_i32, %7, %8) fastmath<contract> : (!fir.ref<!fir.box<none>>, i32, i64, i64) -> ()
622+
%c1_0 = arith.constant 1 : index
623+
%c10_i32_1 = arith.constant 10 : i32
624+
%c1_i32 = arith.constant 1 : i32
625+
%9 = fir.convert %5#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?x!fir.logical<8>>>>>) -> !fir.ref<!fir.box<none>>
626+
%10 = fir.convert %c1_0 : (index) -> i64
627+
%11 = fir.convert %c10_i32_1 : (i32) -> i64
628+
fir.call @_FortranAAllocatableSetBounds(%9, %c1_i32, %10, %11) fastmath<contract> : (!fir.ref<!fir.box<none>>, i32, i64, i64) -> ()
629+
%12 = cuf.allocate %5#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?x!fir.logical<8>>>>> {data_attr = #cuf.cuda<device>} -> i32
630+
%false = arith.constant false
631+
cuf.data_transfer %false to %5#0 {transfer_kind = #cuf.cuda_transfer<host_device>} : i1, !fir.ref<!fir.box<!fir.heap<!fir.array<?x?x!fir.logical<8>>>>>
632+
%13 = fir.load %5#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?x!fir.logical<8>>>>>
633+
%14 = fir.box_addr %13 : (!fir.box<!fir.heap<!fir.array<?x?x!fir.logical<8>>>>) -> !fir.heap<!fir.array<?x?x!fir.logical<8>>>
634+
%15 = fir.convert %14 : (!fir.heap<!fir.array<?x?x!fir.logical<8>>>) -> i64
635+
%c0_i64 = arith.constant 0 : i64
636+
%16 = arith.cmpi ne, %15, %c0_i64 : i64
637+
fir.if %16 {
638+
%17 = cuf.deallocate %5#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?x!fir.logical<8>>>>> {data_attr = #cuf.cuda<device>} -> i32
639+
}
640+
cuf.free %5#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?x!fir.logical<8>>>>> {data_attr = #cuf.cuda<device>}
641+
return
642+
}
643+
644+
// CHECK-LABEL: func.func @_QPsub28()
645+
// CHECK: %[[DESC:.*]] = fir.alloca !fir.box<!fir.logical<8>>
646+
// CHECK: %[[L8:.*]] = fir.alloca !fir.logical<8>
647+
// CHECK: %[[FALSE:.*]] = fir.convert %false{{.*}} : (i1) -> !fir.logical<8>
648+
// CHECK: fir.store %[[FALSE]] to %[[L8]] : !fir.ref<!fir.logical<8>>
649+
// CHECK: %[[EMBOX:.*]] = fir.embox %[[L8]] : (!fir.ref<!fir.logical<8>>) -> !fir.box<!fir.logical<8>>
650+
// CHECK: fir.store %[[EMBOX]] to %[[DESC]] : !fir.ref<!fir.box<!fir.logical<8>>>
651+
// CHECK: %[[BOX_NONE:.*]] = fir.convert %[[DESC]] : (!fir.ref<!fir.box<!fir.logical<8>>>) -> !fir.ref<!fir.box<none>>
652+
// CHECK: fir.call @_FortranACUFDataTransferCstDesc(%{{.*}}, %[[BOX_NONE]], %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> ()
653+
606654
} // end of module
607655

0 commit comments

Comments
 (0)