diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp index d09f47a20b33d..9d911d6bfd406 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -1725,13 +1725,17 @@ struct EmboxOpConversion : public EmboxCommonConversion { }; static bool isDeviceAllocation(mlir::Value val) { + if (auto loadOp = mlir::dyn_cast_or_null(val.getDefiningOp())) + return isDeviceAllocation(loadOp.getMemref()); if (auto convertOp = mlir::dyn_cast_or_null(val.getDefiningOp())) val = convertOp.getValue(); if (auto callOp = mlir::dyn_cast_or_null(val.getDefiningOp())) if (callOp.getCallee() && - callOp.getCallee().value().getRootReference().getValue().starts_with( - RTNAME_STRING(CUFMemAlloc))) + (callOp.getCallee().value().getRootReference().getValue().starts_with( + RTNAME_STRING(CUFMemAlloc)) || + callOp.getCallee().value().getRootReference().getValue().starts_with( + RTNAME_STRING(CUFAllocDesciptor)))) return true; return false; } @@ -2045,7 +2049,8 @@ struct XReboxOpConversion : public EmboxCommonConversion { } dest = insertBaseAddress(rewriter, loc, dest, base); mlir::Value result = - placeInMemoryIfNotGlobalInit(rewriter, rebox.getLoc(), destBoxTy, dest); + placeInMemoryIfNotGlobalInit(rewriter, rebox.getLoc(), destBoxTy, dest, + isDeviceAllocation(rebox.getBox())); rewriter.replaceOp(rebox, result); return mlir::success(); } diff --git a/flang/test/Fir/CUDA/cuda-code-gen.mlir b/flang/test/Fir/CUDA/cuda-code-gen.mlir index a34c2770c5f6c..47c5667a14c95 100644 --- a/flang/test/Fir/CUDA/cuda-code-gen.mlir +++ b/flang/test/Fir/CUDA/cuda-code-gen.mlir @@ -56,3 +56,73 @@ module attributes {dlti.dl_spec = #dlti.dl_spec : vector<2xi64> // CHECK-LABEL: llvm.func @_QQmain() // CHECK: llvm.call @_FortranACUFMemAlloc // CHECK: llvm.call @_FortranACUFAllocDesciptor + +// ----- + +module attributes {dlti.dl_spec = #dlti.dl_spec : vector<2xi64>, i128 = dense<128> : vector<2xi64>, i64 = dense<64> : vector<2xi64>, !llvm.ptr<272> = dense<64> : vector<4xi64>, !llvm.ptr<271> = dense<32> : vector<4xi64>, !llvm.ptr<270> = dense<32> : vector<4xi64>, f128 = dense<128> : vector<2xi64>, f64 = dense<64> : vector<2xi64>, f16 = dense<16> : vector<2xi64>, i32 = dense<32> : vector<2xi64>, i16 = dense<16> : vector<2xi64>, i8 = dense<8> : vector<2xi64>, i1 = dense<8> : vector<2xi64>, !llvm.ptr = dense<64> : vector<4xi64>, "dlti.endianness" = "little", "dlti.stack_alignment" = 128 : i64>} { + func.func @_QQmain() attributes {fir.bindc_name = "p1"} { + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %c16_i32 = arith.constant 16 : i32 + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %0 = fir.alloca i32 {bindc_name = "iblk", uniq_name = "_QFEiblk"} + %1 = fir.alloca i32 {bindc_name = "ithr", uniq_name = "_QFEithr"} + %2 = fir.address_of(@_QQclX64756D6D792E6D6C697200) : !fir.ref> + %c14_i32 = arith.constant 14 : i32 + %c72 = arith.constant 72 : index + %3 = fir.convert %c72 : (index) -> i64 + %4 = fir.convert %2 : (!fir.ref>) -> !fir.ref + %5 = fir.call @_FortranACUFAllocDesciptor(%3, %4, %c14_i32) : (i64, !fir.ref, i32) -> !fir.ref> + %6 = fir.convert %5 : (!fir.ref>) -> !fir.ref>>> + %7 = fir.zero_bits !fir.heap> + %8 = fircg.ext_embox %7(%c0, %c0) {allocator_idx = 2 : i32} : (!fir.heap>, index, index) -> !fir.box>> + fir.store %8 to %6 : !fir.ref>>> + %9 = fir.address_of(@_QQclX64756D6D792E6D6C697200) : !fir.ref> + %c20_i32 = arith.constant 20 : i32 + %c48 = arith.constant 48 : index + %10 = fir.convert %c48 : (index) -> i64 + %11 = fir.convert %9 : (!fir.ref>) -> !fir.ref + %12 = fir.call @_FortranACUFAllocDesciptor(%10, %11, %c20_i32) : (i64, !fir.ref, i32) -> !fir.ref> + %13 = fir.convert %12 : (!fir.ref>) -> !fir.ref>>> + %14 = fir.zero_bits !fir.heap> + %15 = fircg.ext_embox %14(%c0) {allocator_idx = 2 : i32} : (!fir.heap>, index) -> !fir.box>> + fir.store %15 to %13 : !fir.ref>>> + %16 = fir.convert %6 : (!fir.ref>>>) -> !fir.ref> + %17 = fir.convert %c1 : (index) -> i64 + %18 = fir.convert %c16_i32 : (i32) -> i64 + %19 = fir.call @_FortranAAllocatableSetBounds(%16, %c0_i32, %17, %18) fastmath : (!fir.ref>, i32, i64, i64) -> none + %20 = fir.call @_FortranAAllocatableSetBounds(%16, %c1_i32, %17, %18) fastmath : (!fir.ref>, i32, i64, i64) -> none + %21 = fir.address_of(@_QQclX64756D6D792E6D6C697200) : !fir.ref> + %c31_i32 = arith.constant 31 : i32 + %false = arith.constant false + %22 = fir.absent !fir.box + %c-1_i64 = arith.constant -1 : i64 + %23 = fir.convert %6 : (!fir.ref>>>) -> !fir.ref> + %24 = fir.convert %21 : (!fir.ref>) -> !fir.ref + %25 = fir.call @_FortranACUFAllocatableAllocate(%23, %c-1_i64, %false, %22, %24, %c31_i32) : (!fir.ref>, i64, i1, !fir.box, !fir.ref, i32) -> i32 + %26 = fir.convert %13 : (!fir.ref>>>) -> !fir.ref> + %27 = fir.call @_FortranAAllocatableSetBounds(%26, %c0_i32, %17, %18) fastmath : (!fir.ref>, i32, i64, i64) -> none + %28 = fir.address_of(@_QQclX64756D6D792E6D6C697200) : !fir.ref> + %c34_i32 = arith.constant 34 : i32 + %false_0 = arith.constant false + %29 = fir.absent !fir.box + %c-1_i64_1 = arith.constant -1 : i64 + %30 = fir.convert %13 : (!fir.ref>>>) -> !fir.ref> + %31 = fir.convert %28 : (!fir.ref>) -> !fir.ref + %32 = fir.call @_FortranACUFAllocatableAllocate(%30, %c-1_i64_1, %false_0, %29, %31, %c34_i32) : (!fir.ref>, i64, i1, !fir.box, !fir.ref, i32) -> i32 + %33 = fir.load %6 : !fir.ref>>> + %34 = fircg.ext_rebox %33 : (!fir.box>>) -> !fir.box> + return + } + func.func private @_FortranAAllocatableSetBounds(!fir.ref>, i32, i64, i64) -> none attributes {fir.runtime} + fir.global linkonce @_QQclX64756D6D792E6D6C697200 constant : !fir.char<1,11> { + %0 = fir.string_lit "dummy.mlir\00"(11) : !fir.char<1,11> + fir.has_value %0 : !fir.char<1,11> + } + func.func private @_FortranACUFAllocDesciptor(i64, !fir.ref, i32) -> !fir.ref> attributes {fir.runtime} + func.func private @_FortranACUFAllocatableAllocate(!fir.ref>, i64, i1, !fir.box, !fir.ref, i32) -> i32 attributes {fir.runtime} +} + +// CHECK-LABEL: llvm.func @_QQmain() +// CHECK-COUNT-4: llvm.call @_FortranACUFAllocDesciptor