diff --git a/flang/include/flang/Runtime/CUDA/common.h b/flang/include/flang/Runtime/CUDA/common.h index b73bc390ea8c9..4abccf5b341cf 100644 --- a/flang/include/flang/Runtime/CUDA/common.h +++ b/flang/include/flang/Runtime/CUDA/common.h @@ -9,6 +9,7 @@ #ifndef FORTRAN_RUNTIME_CUDA_COMMON_H_ #define FORTRAN_RUNTIME_CUDA_COMMON_H_ +#include "flang/Optimizer/Dialect/CUF/Attributes/CUFAttr.h" #include "flang/Runtime/descriptor.h" #include "flang/Runtime/entry-names.h" @@ -34,4 +35,16 @@ static constexpr unsigned kDeviceToDevice = 2; terminator.Crash("'%s' failed with '%s'", #expr, name); \ }(expr) +static inline unsigned getMemType(cuf::DataAttribute attr) { + if (attr == cuf::DataAttribute::Device) + return kMemTypeDevice; + if (attr == cuf::DataAttribute::Managed) + return kMemTypeManaged; + if (attr == cuf::DataAttribute::Unified) + return kMemTypeUnified; + if (attr == cuf::DataAttribute::Pinned) + return kMemTypePinned; + llvm::report_fatal_error("unsupported memory type"); +} + #endif // FORTRAN_RUNTIME_CUDA_COMMON_H_ diff --git a/flang/lib/Optimizer/Transforms/CufOpConversion.cpp b/flang/lib/Optimizer/Transforms/CufOpConversion.cpp index f8ace2dd96a0d..9c566c10daff5 100644 --- a/flang/lib/Optimizer/Transforms/CufOpConversion.cpp +++ b/flang/lib/Optimizer/Transforms/CufOpConversion.cpp @@ -183,6 +183,29 @@ static bool inDeviceContext(mlir::Operation *op) { return false; } +static int computeWidth(mlir::Location loc, mlir::Type type, + fir::KindMapping &kindMap) { + auto eleTy = fir::unwrapSequenceType(type); + int width = 0; + if (auto t{mlir::dyn_cast(eleTy)}) { + width = t.getWidth() / 8; + } else if (auto t{mlir::dyn_cast(eleTy)}) { + width = t.getWidth() / 8; + } else if (eleTy.isInteger(1)) { + width = 1; + } else if (auto t{mlir::dyn_cast(eleTy)}) { + int kind = t.getFKind(); + width = kindMap.getLogicalBitsize(kind) / 8; + } else if (auto t{mlir::dyn_cast(eleTy)}) { + int kind = t.getFKind(); + int elemSize = kindMap.getRealBitsize(kind) / 8; + width = 2 * elemSize; + } else { + llvm::report_fatal_error("unsupported type"); + } + return width; +} + struct CufAllocOpConversion : public mlir::OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -193,11 +216,6 @@ struct CufAllocOpConversion : public mlir::OpRewritePattern { mlir::LogicalResult matchAndRewrite(cuf::AllocOp op, mlir::PatternRewriter &rewriter) const override { - auto boxTy = mlir::dyn_cast_or_null(op.getInType()); - - // Only convert cuf.alloc that allocates a descriptor. - if (!boxTy) - return failure(); if (inDeviceContext(op.getOperation())) { // In device context just replace the cuf.alloc operation with a fir.alloc @@ -212,11 +230,56 @@ struct CufAllocOpConversion : public mlir::OpRewritePattern { auto mod = op->getParentOfType(); fir::FirOpBuilder builder(rewriter, mod); mlir::Location loc = op.getLoc(); + mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc); + + if (!mlir::dyn_cast_or_null(op.getInType())) { + // Convert scalar and known size array allocations. + mlir::Value bytes; + fir::KindMapping kindMap{fir::getKindMapping(mod)}; + if (fir::isa_trivial(op.getInType())) { + int width = computeWidth(loc, op.getInType(), kindMap); + bytes = + builder.createIntegerConstant(loc, builder.getIndexType(), width); + } else if (auto seqTy = mlir::dyn_cast_or_null( + op.getInType())) { + mlir::Value width = builder.createIntegerConstant( + loc, builder.getIndexType(), + computeWidth(loc, seqTy.getEleTy(), kindMap)); + mlir::Value nbElem; + if (fir::sequenceWithNonConstantShape(seqTy)) { + assert(!op.getShape().empty() && "expect shape with dynamic arrays"); + nbElem = builder.loadIfRef(loc, op.getShape()[0]); + for (unsigned i = 1; i < op.getShape().size(); ++i) { + nbElem = rewriter.create( + loc, nbElem, builder.loadIfRef(loc, op.getShape()[i])); + } + } else { + nbElem = builder.createIntegerConstant(loc, builder.getIndexType(), + seqTy.getConstantArraySize()); + } + bytes = rewriter.create(loc, nbElem, width); + } + mlir::func::FuncOp func = + fir::runtime::getRuntimeFunc(loc, builder); + auto fTy = func.getFunctionType(); + mlir::Value sourceLine = + fir::factory::locationToLineNo(builder, loc, fTy.getInput(3)); + mlir::Value memTy = builder.createIntegerConstant( + loc, builder.getI32Type(), getMemType(op.getDataAttr())); + llvm::SmallVector args{fir::runtime::createArguments( + builder, loc, fTy, bytes, memTy, sourceFile, sourceLine)}; + auto callOp = builder.create(loc, func, args); + auto convOp = builder.createConvert(loc, op.getResult().getType(), + callOp.getResult(0)); + rewriter.replaceOp(op, convOp); + return mlir::success(); + } + + // Convert descriptor allocations to function call. + auto boxTy = mlir::dyn_cast_or_null(op.getInType()); mlir::func::FuncOp func = fir::runtime::getRuntimeFunc(loc, builder); - auto fTy = func.getFunctionType(); - mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc); mlir::Value sourceLine = fir::factory::locationToLineNo(builder, loc, fTy.getInput(2)); @@ -245,26 +308,39 @@ struct CufFreeOpConversion : public mlir::OpRewritePattern { mlir::LogicalResult matchAndRewrite(cuf::FreeOp op, mlir::PatternRewriter &rewriter) const override { - // Only convert cuf.free on descriptor. - if (!mlir::isa(op.getDevptr().getType())) - return failure(); - auto refTy = mlir::dyn_cast(op.getDevptr().getType()); - if (!mlir::isa(refTy.getEleTy())) - return failure(); - if (inDeviceContext(op.getOperation())) { rewriter.eraseOp(op); return mlir::success(); } + if (!mlir::isa(op.getDevptr().getType())) + return failure(); + auto mod = op->getParentOfType(); fir::FirOpBuilder builder(rewriter, mod); mlir::Location loc = op.getLoc(); + mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc); + + auto refTy = mlir::dyn_cast(op.getDevptr().getType()); + if (!mlir::isa(refTy.getEleTy())) { + mlir::func::FuncOp func = + fir::runtime::getRuntimeFunc(loc, builder); + auto fTy = func.getFunctionType(); + mlir::Value sourceLine = + fir::factory::locationToLineNo(builder, loc, fTy.getInput(3)); + mlir::Value memTy = builder.createIntegerConstant( + loc, builder.getI32Type(), getMemType(op.getDataAttr())); + llvm::SmallVector args{fir::runtime::createArguments( + builder, loc, fTy, op.getDevptr(), memTy, sourceFile, sourceLine)}; + builder.create(loc, func, args); + rewriter.eraseOp(op); + return mlir::success(); + } + + // Convert cuf.free on descriptors. mlir::func::FuncOp func = fir::runtime::getRuntimeFunc(loc, builder); - auto fTy = func.getFunctionType(); - mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc); mlir::Value sourceLine = fir::factory::locationToLineNo(builder, loc, fTy.getInput(2)); llvm::SmallVector args{fir::runtime::createArguments( @@ -275,29 +351,6 @@ struct CufFreeOpConversion : public mlir::OpRewritePattern { } }; -static int computeWidth(mlir::Location loc, mlir::Type type, - fir::KindMapping &kindMap) { - auto eleTy = fir::unwrapSequenceType(type); - int width = 0; - if (auto t{mlir::dyn_cast(eleTy)}) { - width = t.getWidth() / 8; - } else if (auto t{mlir::dyn_cast(eleTy)}) { - width = t.getWidth() / 8; - } else if (eleTy.isInteger(1)) { - width = 1; - } else if (auto t{mlir::dyn_cast(eleTy)}) { - int kind = t.getFKind(); - width = kindMap.getLogicalBitsize(kind) / 8; - } else if (auto t{mlir::dyn_cast(eleTy)}) { - int kind = t.getFKind(); - int elemSize = kindMap.getRealBitsize(kind) / 8; - width = 2 * elemSize; - } else { - llvm::report_fatal_error("unsupported type"); - } - return width; -} - static mlir::Value createConvertOp(mlir::PatternRewriter &rewriter, mlir::Location loc, mlir::Type toTy, mlir::Value val) { @@ -456,16 +509,6 @@ class CufOpConversion : public fir::impl::CufOpConversionBase { fir::support::getOrSetDataLayout(module, /*allowDefaultLayout=*/false); fir::LLVMTypeConverter typeConverter(module, /*applyTBAA=*/false, /*forceUnifiedTBAATree=*/false, *dl); - target.addDynamicallyLegalOp([](::cuf::AllocOp op) { - return !mlir::isa(op.getInType()); - }); - target.addDynamicallyLegalOp([](::cuf::FreeOp op) { - if (auto refTy = mlir::dyn_cast_or_null( - op.getDevptr().getType())) { - return !mlir::isa(refTy.getEleTy()); - } - return true; - }); target.addDynamicallyLegalOp( [](::cuf::DataTransferOp op) { mlir::Type srcTy = fir::unwrapRefType(op.getSrc().getType()); diff --git a/flang/test/Fir/CUDA/cuda-alloc-free.fir b/flang/test/Fir/CUDA/cuda-alloc-free.fir new file mode 100644 index 0000000000000..25821418a40f1 --- /dev/null +++ b/flang/test/Fir/CUDA/cuda-alloc-free.fir @@ -0,0 +1,64 @@ +// RUN: fir-opt --cuf-convert %s | FileCheck %s + +module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry, dense<64> : vector<4xi64>>, #dlti.dl_entry, dense<32> : vector<4xi64>>, #dlti.dl_entry, dense<32> : vector<4xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<4xi64>>, #dlti.dl_entry<"dlti.endianness", "little">, #dlti.dl_entry<"dlti.stack_alignment", 128 : i64>>} { + +func.func @_QPsub1() { + %0 = cuf.alloc i32 {bindc_name = "idev", data_attr = #cuf.cuda, uniq_name = "_QFsub1Eidev"} -> !fir.ref + %1:2 = hlfir.declare %0 {data_attr = #cuf.cuda, uniq_name = "_QFsub1Eidev"} : (!fir.ref) -> (!fir.ref, !fir.ref) + cuf.free %1#1 : !fir.ref {data_attr = #cuf.cuda} + return +} + +// CHECK-LABEL: func.func @_QPsub1() +// CHECK: %[[BYTES:.*]] = fir.convert %c4{{.*}} : (index) -> i64 +// CHECK: %[[ALLOC:.*]] = fir.call @_FortranACUFMemAlloc(%[[BYTES]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (i64, i32, !fir.ref, i32) -> !fir.llvm_ptr +// CHECK: %[[CONV:.*]] = fir.convert %3 : (!fir.llvm_ptr) -> !fir.ref +// CHECK: %[[DECL:.*]]:2 = hlfir.declare %[[CONV]] {data_attr = #cuf.cuda, uniq_name = "_QFsub1Eidev"} : (!fir.ref) -> (!fir.ref, !fir.ref) +// CHECK: %[[DEVPTR:.*]] = fir.convert %[[DECL]]#1 : (!fir.ref) -> !fir.llvm_ptr +// CHECK: fir.call @_FortranACUFMemFree(%[[DEVPTR]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr, i32, !fir.ref, i32) -> none + +func.func @_QPsub2() { + %0 = cuf.alloc !fir.array<10xf32> {bindc_name = "a", data_attr = #cuf.cuda, uniq_name = "_QMcuda_varFcuda_alloc_freeEa"} -> !fir.ref> + cuf.free %0 : !fir.ref> {data_attr = #cuf.cuda} + return +} + +// CHECK-LABEL: func.func @_QPsub2() +// CHECK: %[[BYTES:.*]] = arith.muli %c10{{.*}}, %c4{{.*}} : index +// CHECK: %[[CONV_BYTES:.*]] = fir.convert %[[BYTES]] : (index) -> i64 +// CHECK: %{{.*}} = fir.call @_FortranACUFMemAlloc(%[[CONV_BYTES]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (i64, i32, !fir.ref, i32) -> !fir.llvm_ptr +// CHECK: fir.call @_FortranACUFMemFree + +func.func @_QPsub3(%arg0: !fir.ref {fir.bindc_name = "n"}, %arg1: !fir.ref {fir.bindc_name = "m"}) { + %0 = fir.dummy_scope : !fir.dscope + %1:2 = hlfir.declare %arg0 dummy_scope %0 {uniq_name = "_QFsub3En"} : (!fir.ref, !fir.dscope) -> (!fir.ref, !fir.ref) + %2:2 = hlfir.declare %arg1 dummy_scope %0 {uniq_name = "_QFsub3Em"} : (!fir.ref, !fir.dscope) -> (!fir.ref, !fir.ref) + %3 = fir.load %1#0 : !fir.ref + %4 = fir.convert %3 : (i32) -> i64 + %5 = fir.convert %4 : (i64) -> index + %c0 = arith.constant 0 : index + %6 = arith.cmpi sgt, %5, %c0 : index + %7 = arith.select %6, %5, %c0 : index + %8 = fir.load %2#0 : !fir.ref + %9 = fir.convert %8 : (i32) -> i64 + %10 = fir.convert %9 : (i64) -> index + %c0_0 = arith.constant 0 : index + %11 = arith.cmpi sgt, %10, %c0_0 : index + %12 = arith.select %11, %10, %c0_0 : index + %13 = cuf.alloc !fir.array, %7, %12 : index, index {bindc_name = "idev", data_attr = #cuf.cuda, uniq_name = "_QFsub3Eidev"} -> !fir.ref> + %14 = fir.shape %7, %12 : (index, index) -> !fir.shape<2> + %15:2 = hlfir.declare %13(%14) {data_attr = #cuf.cuda, uniq_name = "_QFsub3Eidev"} : (!fir.ref>, !fir.shape<2>) -> (!fir.box>, !fir.ref>) + cuf.free %15#1 : !fir.ref> {data_attr = #cuf.cuda} + return +} + +// CHECK-LABEL: func.func @_QPsub3 +// CHECK: %[[N:.*]] = arith.select +// CHECK: %[[M:.*]] = arith.select +// CHECK: %[[NBELEM:.*]] = arith.muli %[[N]], %[[M]] : index +// CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %c4{{.*}} : index +// CHECK: %[[CONV_BYTES:.*]] = fir.convert %[[BYTES]] : (index) -> i64 +// CHECK: %{{.*}} = fir.call @_FortranACUFMemAlloc(%[[CONV_BYTES]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (i64, i32, !fir.ref, i32) -> !fir.llvm_ptr +// CHECK: fir.call @_FortranACUFMemFree + +} // end module diff --git a/flang/test/Fir/CUDA/cuda-allocate.fir b/flang/test/Fir/CUDA/cuda-allocate.fir index 65c68bb69301a..d68ff894d5af5 100644 --- a/flang/test/Fir/CUDA/cuda-allocate.fir +++ b/flang/test/Fir/CUDA/cuda-allocate.fir @@ -26,17 +26,6 @@ func.func @_QPsub1() { // CHECK: %[[BOX_NONE:.*]] = fir.convert %[[DECL_DESC]]#1 : (!fir.ref>>>) -> !fir.ref> // CHECK: fir.call @_FortranACUFFreeDesciptor(%[[BOX_NONE]], %{{.*}}, %{{.*}}) : (!fir.ref>, !fir.ref, i32) -> none -// Check operations that should not be transformed yet. -func.func @_QPsub2() { - %0 = cuf.alloc !fir.array<10xf32> {bindc_name = "a", data_attr = #cuf.cuda, uniq_name = "_QMcuda_varFcuda_alloc_freeEa"} -> !fir.ref> - cuf.free %0 : !fir.ref> {data_attr = #cuf.cuda} - return -} - -// CHECK-LABEL: func.func @_QPsub2() -// CHECK: cuf.alloc !fir.array<10xf32> -// CHECK: cuf.free %{{.*}} : !fir.ref> - fir.global @_QMmod1Ea {data_attr = #cuf.cuda} : !fir.box>> { %0 = fir.zero_bits !fir.heap> %c0 = arith.constant 0 : index