From f436c3fbd7cf8d23902e5f0ac96ecd3ebabdb470 Mon Sep 17 00:00:00 2001 From: Valentin Clement Date: Fri, 8 Nov 2024 14:02:14 -0800 Subject: [PATCH] [flang][cuda] Support derived type in cuf.alloc --- flang/lib/Optimizer/Transforms/CUFOpConversion.cpp | 7 +++++++ flang/test/Fir/CUDA/cuda-alloc-free.fir | 12 ++++++++++++ 2 files changed, 19 insertions(+) diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp index 881f54133ce73..8e9de3d328152 100644 --- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp +++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp @@ -337,6 +337,13 @@ struct CUFAllocOpConversion : public mlir::OpRewritePattern { seqTy.getConstantArraySize()); } bytes = rewriter.create(loc, nbElem, width); + } else if (fir::isa_derived(op.getInType())) { + mlir::Type structTy = typeConverter->convertType(op.getInType()); + std::size_t structSize = dl->getTypeSizeInBits(structTy) / 8; + bytes = builder.createIntegerConstant(loc, builder.getIndexType(), + structSize); + } else { + mlir::emitError(loc, "unsupported type in cuf.alloc\n"); } mlir::func::FuncOp func = fir::runtime::getRuntimeFunc(loc, builder); diff --git a/flang/test/Fir/CUDA/cuda-alloc-free.fir b/flang/test/Fir/CUDA/cuda-alloc-free.fir index 25821418a40f1..88b1a00e4a5b2 100644 --- a/flang/test/Fir/CUDA/cuda-alloc-free.fir +++ b/flang/test/Fir/CUDA/cuda-alloc-free.fir @@ -61,4 +61,16 @@ func.func @_QPsub3(%arg0: !fir.ref {fir.bindc_name = "n"}, %arg1: !fir.ref< // CHECK: %{{.*}} = fir.call @_FortranACUFMemAlloc(%[[CONV_BYTES]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (i64, i32, !fir.ref, i32) -> !fir.llvm_ptr // CHECK: fir.call @_FortranACUFMemFree +func.func @_QPtest_type() { + %0 = cuf.alloc !fir.type<_QMbarTcmplx{id:i32,c:complex}> {bindc_name = "a", data_attr = #cuf.cuda, uniq_name = "_QFtest_typeEa"} -> !fir.ref}>> + %1 = fir.declare %0 {data_attr = #cuf.cuda, uniq_name = "_QFtest_typeEa"} : (!fir.ref}>>) -> !fir.ref}>> + cuf.free %1 : !fir.ref}>> {data_attr = #cuf.cuda} + return +} + +// CHECK-LABEL: func.func @_QPtest_type() +// CHECK: %[[BYTES:.*]] = arith.constant 12 : index +// CHECK: %[[CONV_BYTES:.*]] = fir.convert %[[BYTES]] : (index) -> i64 +// CHECK: fir.call @_FortranACUFMemAlloc(%[[CONV_BYTES]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (i64, i32, !fir.ref, i32) -> !fir.llvm_ptr + } // end module