Skip to content

Commit 01a87f2

Browse files
authored
[flang][cuda] Make sure dstEleTy is set when used in CUFOpConversion (llvm#163240)
When the src is an i1, we use the dst element type. In some case, the dst element type was null. Make sure we pass one to `emboxSrc` and add an assertion when we use it to catch it in case it is null.
1 parent 13e563e commit 01a87f2

File tree

2 files changed

+43
-1
lines changed

2 files changed

+43
-1
lines changed

flang/lib/Optimizer/Transforms/CUFOpConversion.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -558,6 +558,7 @@ static mlir::Value emboxSrc(mlir::PatternRewriter &rewriter,
558558
if (srcTy.isInteger(1)) {
559559
// i1 is not a supported type in the descriptor and it is actually coming
560560
// from a LOGICAL constant. Use the destination type to avoid mismatch.
561+
assert(dstEleTy && "expect dst element type to be set");
561562
srcTy = dstEleTy;
562563
src = createConvertOp(rewriter, loc, srcTy, src);
563564
addr = builder.createTemporary(loc, srcTy);
@@ -652,7 +653,8 @@ struct CUFDataTransferOpConversion
652653
// Initialization of an array from a scalar value should be implemented
653654
// via a kernel launch. Use the flang runtime via the Assign function
654655
// until we have more infrastructure.
655-
mlir::Value src = emboxSrc(rewriter, op, symtab);
656+
mlir::Type dstEleTy = fir::unwrapInnerType(fir::unwrapRefType(dstTy));
657+
mlir::Value src = emboxSrc(rewriter, op, symtab, dstEleTy);
656658
mlir::Value dst = emboxDst(rewriter, op, symtab);
657659
mlir::func::FuncOp func =
658660
fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferCstDesc)>(

flang/test/Fir/CUDA/cuda-data-transfer.fir

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -651,5 +651,45 @@ func.func @_QPsub28() {
651651
// CHECK: %[[BOX_NONE:.*]] = fir.convert %[[DESC]] : (!fir.ref<!fir.box<!fir.logical<8>>>) -> !fir.ref<!fir.box<none>>
652652
// CHECK: fir.call @_FortranACUFDataTransferCstDesc(%{{.*}}, %[[BOX_NONE]], %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> ()
653653

654+
func.func @_QPtesti4(%arg0: !fir.ref<i32> {fir.bindc_name = "n1"}, %arg1: !fir.ref<i32> {fir.bindc_name = "n2"}, %arg2: !fir.ref<i32> {fir.bindc_name = "n3"}, %arg3: !fir.ref<i32> {fir.bindc_name = "n4"}) {
655+
%true = arith.constant true
656+
%c0 = arith.constant 0 : index
657+
%c2_i32 = arith.constant 2 : i32
658+
%0 = fir.dummy_scope : !fir.dscope
659+
%1:2 = hlfir.declare %arg0 dummy_scope %0 {uniq_name = "_QFtesti4En1"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
660+
%2:2 = hlfir.declare %arg1 dummy_scope %0 {uniq_name = "_QFtesti4En2"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
661+
%3:2 = hlfir.declare %arg2 dummy_scope %0 {uniq_name = "_QFtesti4En3"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
662+
%4:2 = hlfir.declare %arg3 dummy_scope %0 {uniq_name = "_QFtesti4En4"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
663+
%5 = fir.load %1#0 : !fir.ref<i32>
664+
%6 = arith.divsi %5, %c2_i32 : i32
665+
%7 = fir.convert %6 : (i32) -> index
666+
%8 = arith.cmpi sgt, %7, %c0 : index
667+
%9 = arith.select %8, %7, %c0 : index
668+
%10 = fir.load %2#0 : !fir.ref<i32>
669+
%11 = arith.divsi %10, %c2_i32 : i32
670+
%12 = fir.convert %11 : (i32) -> index
671+
%13 = arith.cmpi sgt, %12, %c0 : index
672+
%14 = arith.select %13, %12, %c0 : index
673+
%15 = fir.load %3#0 : !fir.ref<i32>
674+
%16 = arith.divsi %15, %c2_i32 : i32
675+
%17 = fir.convert %16 : (i32) -> index
676+
%18 = arith.cmpi sgt, %17, %c0 : index
677+
%19 = arith.select %18, %17, %c0 : index
678+
%20 = fir.load %4#0 : !fir.ref<i32>
679+
%21 = arith.divsi %20, %c2_i32 : i32
680+
%22 = fir.convert %21 : (i32) -> index
681+
%23 = arith.cmpi sgt, %22, %c0 : index
682+
%24 = arith.select %23, %22, %c0 : index
683+
%25 = cuf.alloc !fir.array<?x?x?x?x!fir.logical<4>>, %9, %14, %19, %24 : index, index, index, index {bindc_name = "lma", data_attr = #cuf.cuda<managed>, uniq_name = "_QFtesti4Elma"} -> !fir.ref<!fir.array<?x?x?x?x!fir.logical<4>>>
684+
%26 = fir.shape %9, %14, %19, %24 : (index, index, index, index) -> !fir.shape<4>
685+
%27:2 = hlfir.declare %25(%26) {data_attr = #cuf.cuda<managed>, uniq_name = "_QFtesti4Elma"} : (!fir.ref<!fir.array<?x?x?x?x!fir.logical<4>>>, !fir.shape<4>) -> (!fir.box<!fir.array<?x?x?x?x!fir.logical<4>>>, !fir.ref<!fir.array<?x?x?x?x!fir.logical<4>>>)
686+
cuf.data_transfer %true to %27#1, %26 : !fir.shape<4> {transfer_kind = #cuf.cuda_transfer<host_device>} : i1, !fir.ref<!fir.array<?x?x?x?x!fir.logical<4>>>
687+
cuf.free %27#1 : !fir.ref<!fir.array<?x?x?x?x!fir.logical<4>>> {data_attr = #cuf.cuda<managed>}
688+
return
689+
}
690+
691+
// CHECK-LABEL: func.func @_QPtesti4
692+
// CHECK: fir.call @_FortranACUFDataTransferCstDesc
693+
654694
} // end of module
655695

0 commit comments

Comments
 (0)