4141#include " mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
4242#include " mlir/Dialect/Arith/IR/Arith.h"
4343#include " mlir/Dialect/DLTI/DLTI.h"
44+ #include " mlir/Dialect/GPU/IR/GPUDialect.h"
4445#include " mlir/Dialect/LLVMIR/LLVMAttrs.h"
4546#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
4647#include " mlir/Dialect/LLVMIR/Transforms/AddComdats.h"
@@ -920,17 +921,19 @@ struct EmboxCharOpConversion : public fir::FIROpConversion<fir::EmboxCharOp> {
920921};
921922} // namespace
922923
923- // / Return the LLVMFuncOp corresponding to the standard malloc call.
924+ template < typename ModuleOp>
924925static mlir::SymbolRefAttr
925- getMalloc (fir::AllocMemOp op, mlir::ConversionPatternRewriter &rewriter) {
926+ getMallocInModule (ModuleOp mod, fir::AllocMemOp op,
927+ mlir::ConversionPatternRewriter &rewriter) {
926928 static constexpr char mallocName[] = " malloc" ;
927- auto module = op-> getParentOfType <mlir::ModuleOp>();
928- if ( auto mallocFunc = module . lookupSymbol <mlir::LLVM::LLVMFuncOp>(mallocName))
929+ if ( auto mallocFunc =
930+ mod. template lookupSymbol <mlir::LLVM::LLVMFuncOp>(mallocName))
929931 return mlir::SymbolRefAttr::get (mallocFunc);
930- if (auto userMalloc = module .lookupSymbol <mlir::func::FuncOp>(mallocName))
932+ if (auto userMalloc =
933+ mod.template lookupSymbol <mlir::func::FuncOp>(mallocName))
931934 return mlir::SymbolRefAttr::get (userMalloc);
932- mlir::OpBuilder moduleBuilder (
933- op-> getParentOfType < mlir::ModuleOp>() .getBodyRegion ());
935+
936+ mlir::OpBuilder moduleBuilder (mod .getBodyRegion ());
934937 auto indexType = mlir::IntegerType::get (op.getContext (), 64 );
935938 auto mallocDecl = moduleBuilder.create <mlir::LLVM::LLVMFuncOp>(
936939 op.getLoc (), mallocName,
@@ -940,6 +943,15 @@ getMalloc(fir::AllocMemOp op, mlir::ConversionPatternRewriter &rewriter) {
940943 return mlir::SymbolRefAttr::get (mallocDecl);
941944}
942945
946+ // / Return the LLVMFuncOp corresponding to the standard malloc call.
947+ static mlir::SymbolRefAttr
948+ getMalloc (fir::AllocMemOp op, mlir::ConversionPatternRewriter &rewriter) {
949+ if (auto mod = op->getParentOfType <mlir::gpu::GPUModuleOp>())
950+ return getMallocInModule (mod, op, rewriter);
951+ auto mod = op->getParentOfType <mlir::ModuleOp>();
952+ return getMallocInModule (mod, op, rewriter);
953+ }
954+
943955// / Helper function for generating the LLVM IR that computes the distance
944956// / in bytes between adjacent elements pointed to by a pointer
945957// / of type \p ptrTy. The result is returned as a value of \p idxTy integer
@@ -1016,18 +1028,20 @@ struct AllocMemOpConversion : public fir::FIROpConversion<fir::AllocMemOp> {
10161028} // namespace
10171029
10181030// / Return the LLVMFuncOp corresponding to the standard free call.
1019- static mlir::SymbolRefAttr getFree (fir::FreeMemOp op,
1020- mlir::ConversionPatternRewriter &rewriter) {
1031+ template <typename ModuleOp>
1032+ static mlir::SymbolRefAttr
1033+ getFreeInModule (ModuleOp mod, fir::FreeMemOp op,
1034+ mlir::ConversionPatternRewriter &rewriter) {
10211035 static constexpr char freeName[] = " free" ;
1022- auto module = op->getParentOfType <mlir::ModuleOp>();
10231036 // Check if free already defined in the module.
1024- if (auto freeFunc = module .lookupSymbol <mlir::LLVM::LLVMFuncOp>(freeName))
1037+ if (auto freeFunc =
1038+ mod.template lookupSymbol <mlir::LLVM::LLVMFuncOp>(freeName))
10251039 return mlir::SymbolRefAttr::get (freeFunc);
10261040 if (auto freeDefinedByUser =
1027- module . lookupSymbol <mlir::func::FuncOp>(freeName))
1041+ mod. template lookupSymbol <mlir::func::FuncOp>(freeName))
10281042 return mlir::SymbolRefAttr::get (freeDefinedByUser);
10291043 // Create llvm declaration for free.
1030- mlir::OpBuilder moduleBuilder (module .getBodyRegion ());
1044+ mlir::OpBuilder moduleBuilder (mod .getBodyRegion ());
10311045 auto voidType = mlir::LLVM::LLVMVoidType::get (op.getContext ());
10321046 auto freeDecl = moduleBuilder.create <mlir::LLVM::LLVMFuncOp>(
10331047 rewriter.getUnknownLoc (), freeName,
@@ -1037,6 +1051,14 @@ static mlir::SymbolRefAttr getFree(fir::FreeMemOp op,
10371051 return mlir::SymbolRefAttr::get (freeDecl);
10381052}
10391053
1054+ static mlir::SymbolRefAttr getFree (fir::FreeMemOp op,
1055+ mlir::ConversionPatternRewriter &rewriter) {
1056+ if (auto mod = op->getParentOfType <mlir::gpu::GPUModuleOp>())
1057+ return getFreeInModule (mod, op, rewriter);
1058+ auto mod = op->getParentOfType <mlir::ModuleOp>();
1059+ return getFreeInModule (mod, op, rewriter);
1060+ }
1061+
10401062static unsigned getDimension (mlir::LLVM::LLVMArrayType ty) {
10411063 unsigned result = 1 ;
10421064 for (auto eleTy =
@@ -3730,6 +3752,7 @@ class FIRToLLVMLowering
37303752 mlir::configureOpenMPToLLVMConversionLegality (target, typeConverter);
37313753 target.addLegalDialect <mlir::omp::OpenMPDialect>();
37323754 target.addLegalDialect <mlir::acc::OpenACCDialect>();
3755+ target.addLegalDialect <mlir::gpu::GPUDialect>();
37333756
37343757 // required NOPs for applying a full conversion
37353758 target.addLegalOp <mlir::ModuleOp>();
0 commit comments