88
99#include " flang/Common/Fortran.h"
1010#include " flang/Optimizer/Builder/Runtime/RTBuilder.h"
11+ #include " flang/Optimizer/CodeGen/TypeConverter.h"
1112#include " flang/Optimizer/Dialect/CUF/CUFOps.h"
1213#include " flang/Optimizer/Dialect/FIRDialect.h"
1314#include " flang/Optimizer/Dialect/FIROps.h"
1415#include " flang/Optimizer/HLFIR/HLFIROps.h"
16+ #include " flang/Optimizer/Support/DataLayout.h"
17+ #include " flang/Runtime/CUDA/descriptor.h"
1518#include " flang/Runtime/allocatable.h"
1619#include " mlir/Pass/Pass.h"
1720#include " mlir/Transforms/DialectConversion.h"
@@ -25,6 +28,7 @@ namespace fir {
2528using namespace fir ;
2629using namespace mlir ;
2730using namespace Fortran ::runtime;
31+ using namespace Fortran ::runtime::cuf;
2832
2933namespace {
3034
@@ -75,11 +79,11 @@ static mlir::LogicalResult convertOpToCall(OpTy op,
7579}
7680
7781struct CufAllocateOpConversion
78- : public mlir::OpRewritePattern<cuf::AllocateOp> {
82+ : public mlir::OpRewritePattern<:: cuf::AllocateOp> {
7983 using OpRewritePattern::OpRewritePattern;
8084
8185 mlir::LogicalResult
82- matchAndRewrite (cuf::AllocateOp op,
86+ matchAndRewrite (:: cuf::AllocateOp op,
8387 mlir::PatternRewriter &rewriter) const override {
8488 // TODO: Allocation with source will need a new entry point in the runtime.
8589 if (op.getSource ())
@@ -108,16 +112,16 @@ struct CufAllocateOpConversion
108112 mlir::func::FuncOp func =
109113 fir::runtime::getRuntimeFunc<mkRTKey (AllocatableAllocate)>(loc,
110114 builder);
111- return convertOpToCall<cuf::AllocateOp>(op, rewriter, func);
115+ return convertOpToCall<:: cuf::AllocateOp>(op, rewriter, func);
112116 }
113117};
114118
115119struct CufDeallocateOpConversion
116- : public mlir::OpRewritePattern<cuf::DeallocateOp> {
120+ : public mlir::OpRewritePattern<:: cuf::DeallocateOp> {
117121 using OpRewritePattern::OpRewritePattern;
118122
119123 mlir::LogicalResult
120- matchAndRewrite (cuf::DeallocateOp op,
124+ matchAndRewrite (:: cuf::DeallocateOp op,
121125 mlir::PatternRewriter &rewriter) const override {
122126 // TODO: Allocation of module variable will need more work as the descriptor
123127 // will be duplicated and needs to be synced after allocation.
@@ -133,7 +137,84 @@ struct CufDeallocateOpConversion
133137 mlir::func::FuncOp func =
134138 fir::runtime::getRuntimeFunc<mkRTKey (AllocatableDeallocate)>(loc,
135139 builder);
136- return convertOpToCall<cuf::DeallocateOp>(op, rewriter, func);
140+ return convertOpToCall<::cuf::DeallocateOp>(op, rewriter, func);
141+ }
142+ };
143+
144+ struct CufAllocOpConversion : public mlir ::OpRewritePattern<::cuf::AllocOp> {
145+ using OpRewritePattern::OpRewritePattern;
146+
147+ CufAllocOpConversion (mlir::MLIRContext *context, mlir::DataLayout *dl,
148+ fir::LLVMTypeConverter *typeConverter)
149+ : OpRewritePattern(context), dl{dl}, typeConverter{typeConverter} {}
150+
151+ mlir::LogicalResult
152+ matchAndRewrite (::cuf::AllocOp op,
153+ mlir::PatternRewriter &rewriter) const override {
154+ auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType ());
155+
156+ // Only convert cuf.alloc that allocates a descriptor.
157+ if (!boxTy)
158+ return failure ();
159+
160+ auto mod = op->getParentOfType <mlir::ModuleOp>();
161+ fir::FirOpBuilder builder (rewriter, mod);
162+ mlir::Location loc = op.getLoc ();
163+ mlir::func::FuncOp func =
164+ fir::runtime::getRuntimeFunc<mkRTKey (CUFAllocDesciptor)>(loc, builder);
165+
166+ auto fTy = func.getFunctionType ();
167+ mlir::Value sourceFile = fir::factory::locationToFilename (builder, loc);
168+ mlir::Value sourceLine =
169+ fir::factory::locationToLineNo (builder, loc, fTy .getInput (2 ));
170+
171+ mlir::Type structTy = typeConverter->convertBoxTypeAsStruct (boxTy);
172+ std::size_t boxSize = dl->getTypeSizeInBits (structTy) / 8 ;
173+ mlir::Value sizeInBytes =
174+ builder.createIntegerConstant (loc, builder.getIndexType (), boxSize);
175+
176+ llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments (
177+ builder, loc, fTy , sizeInBytes, sourceFile, sourceLine)};
178+ auto callOp = builder.create <fir::CallOp>(loc, func, args);
179+ auto convOp = builder.createConvert (loc, op.getResult ().getType (),
180+ callOp.getResult (0 ));
181+ rewriter.replaceOp (op, convOp);
182+ return mlir::success ();
183+ }
184+
185+ private:
186+ mlir::DataLayout *dl;
187+ fir::LLVMTypeConverter *typeConverter;
188+ };
189+
190+ struct CufFreeOpConversion : public mlir ::OpRewritePattern<::cuf::FreeOp> {
191+ using OpRewritePattern::OpRewritePattern;
192+
193+ mlir::LogicalResult
194+ matchAndRewrite (::cuf::FreeOp op,
195+ mlir::PatternRewriter &rewriter) const override {
196+ // Only convert cuf.free on descriptor.
197+ if (!mlir::isa<fir::ReferenceType>(op.getDevptr ().getType ()))
198+ return failure ();
199+ auto refTy = mlir::dyn_cast<fir::ReferenceType>(op.getDevptr ().getType ());
200+ if (!mlir::isa<fir::BaseBoxType>(refTy.getEleTy ()))
201+ return failure ();
202+
203+ auto mod = op->getParentOfType <mlir::ModuleOp>();
204+ fir::FirOpBuilder builder (rewriter, mod);
205+ mlir::Location loc = op.getLoc ();
206+ mlir::func::FuncOp func =
207+ fir::runtime::getRuntimeFunc<mkRTKey (CUFFreeDesciptor)>(loc, builder);
208+
209+ auto fTy = func.getFunctionType ();
210+ mlir::Value sourceFile = fir::factory::locationToFilename (builder, loc);
211+ mlir::Value sourceLine =
212+ fir::factory::locationToLineNo (builder, loc, fTy .getInput (2 ));
213+ llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments (
214+ builder, loc, fTy , op.getDevptr (), sourceFile, sourceLine)};
215+ builder.create <fir::CallOp>(loc, func, args);
216+ rewriter.eraseOp (op);
217+ return mlir::success ();
137218 }
138219};
139220
@@ -143,8 +224,22 @@ class CufOpConversion : public fir::impl::CufOpConversionBase<CufOpConversion> {
143224 auto *ctx = &getContext ();
144225 mlir::RewritePatternSet patterns (ctx);
145226 mlir::ConversionTarget target (*ctx);
146- target.addIllegalOp <cuf::AllocateOp, cuf::DeallocateOp>();
147- patterns.insert <CufAllocateOpConversion, CufDeallocateOpConversion>(ctx);
227+
228+ mlir::Operation *op = getOperation ();
229+ mlir::ModuleOp module = mlir::dyn_cast<mlir::ModuleOp>(op);
230+ if (!module )
231+ return signalPassFailure ();
232+
233+ std::optional<mlir::DataLayout> dl =
234+ fir::support::getOrSetDataLayout (module , /* allowDefaultLayout=*/ false );
235+ fir::LLVMTypeConverter typeConverter (module , /* applyTBAA=*/ false ,
236+ /* forceUnifiedTBAATree=*/ false , *dl);
237+
238+ target.addIllegalOp <::cuf::AllocOp, ::cuf::AllocateOp, ::cuf::DeallocateOp,
239+ ::cuf::FreeOp>();
240+ patterns.insert <CufAllocOpConversion>(ctx, &*dl, &typeConverter);
241+ patterns.insert <CufAllocateOpConversion, CufDeallocateOpConversion,
242+ CufFreeOpConversion>(ctx);
148243 if (mlir::failed (mlir::applyPartialConversion (getOperation (), target,
149244 std::move (patterns)))) {
150245 mlir::emitError (mlir::UnknownLoc::get (ctx),
0 commit comments