@@ -183,6 +183,29 @@ static bool inDeviceContext(mlir::Operation *op) {
183183 return false ;
184184}
185185
186+ static int computeWidth (mlir::Location loc, mlir::Type type,
187+ fir::KindMapping &kindMap) {
188+ auto eleTy = fir::unwrapSequenceType (type);
189+ int width = 0 ;
190+ if (auto t{mlir::dyn_cast<mlir::IntegerType>(eleTy)}) {
191+ width = t.getWidth () / 8 ;
192+ } else if (auto t{mlir::dyn_cast<mlir::FloatType>(eleTy)}) {
193+ width = t.getWidth () / 8 ;
194+ } else if (eleTy.isInteger (1 )) {
195+ width = 1 ;
196+ } else if (auto t{mlir::dyn_cast<fir::LogicalType>(eleTy)}) {
197+ int kind = t.getFKind ();
198+ width = kindMap.getLogicalBitsize (kind) / 8 ;
199+ } else if (auto t{mlir::dyn_cast<fir::ComplexType>(eleTy)}) {
200+ int kind = t.getFKind ();
201+ int elemSize = kindMap.getRealBitsize (kind) / 8 ;
202+ width = 2 * elemSize;
203+ } else {
204+ llvm::report_fatal_error (" unsupported type" );
205+ }
206+ return width;
207+ }
208+
186209struct CufAllocOpConversion : public mlir ::OpRewritePattern<cuf::AllocOp> {
187210 using OpRewritePattern::OpRewritePattern;
188211
@@ -193,11 +216,6 @@ struct CufAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
193216 mlir::LogicalResult
194217 matchAndRewrite (cuf::AllocOp op,
195218 mlir::PatternRewriter &rewriter) const override {
196- auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType ());
197-
198- // Only convert cuf.alloc that allocates a descriptor.
199- if (!boxTy)
200- return failure ();
201219
202220 if (inDeviceContext (op.getOperation ())) {
203221 // In device context just replace the cuf.alloc operation with a fir.alloc
@@ -212,11 +230,56 @@ struct CufAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
212230 auto mod = op->getParentOfType <mlir::ModuleOp>();
213231 fir::FirOpBuilder builder (rewriter, mod);
214232 mlir::Location loc = op.getLoc ();
233+ mlir::Value sourceFile = fir::factory::locationToFilename (builder, loc);
234+
235+ if (!mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType ())) {
236+ // Convert scalar and known size array allocations.
237+ mlir::Value bytes;
238+ fir::KindMapping kindMap{fir::getKindMapping (mod)};
239+ if (fir::isa_trivial (op.getInType ())) {
240+ int width = computeWidth (loc, op.getInType (), kindMap);
241+ bytes =
242+ builder.createIntegerConstant (loc, builder.getIndexType (), width);
243+ } else if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(
244+ op.getInType ())) {
245+ mlir::Value width = builder.createIntegerConstant (
246+ loc, builder.getIndexType (),
247+ computeWidth (loc, seqTy.getEleTy (), kindMap));
248+ mlir::Value nbElem;
249+ if (fir::sequenceWithNonConstantShape (seqTy)) {
250+ assert (!op.getShape ().empty () && " expect shape with dynamic arrays" );
251+ nbElem = builder.loadIfRef (loc, op.getShape ()[0 ]);
252+ for (unsigned i = 1 ; i < op.getShape ().size (); ++i) {
253+ nbElem = rewriter.create <mlir::arith::MulIOp>(
254+ loc, nbElem, builder.loadIfRef (loc, op.getShape ()[i]));
255+ }
256+ } else {
257+ nbElem = builder.createIntegerConstant (loc, builder.getIndexType (),
258+ seqTy.getConstantArraySize ());
259+ }
260+ bytes = rewriter.create <mlir::arith::MulIOp>(loc, nbElem, width);
261+ }
262+ mlir::func::FuncOp func =
263+ fir::runtime::getRuntimeFunc<mkRTKey (CUFMemAlloc)>(loc, builder);
264+ auto fTy = func.getFunctionType ();
265+ mlir::Value sourceLine =
266+ fir::factory::locationToLineNo (builder, loc, fTy .getInput (3 ));
267+ mlir::Value memTy = builder.createIntegerConstant (
268+ loc, builder.getI32Type (), getMemType (op.getDataAttr ()));
269+ llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments (
270+ builder, loc, fTy , bytes, memTy, sourceFile, sourceLine)};
271+ auto callOp = builder.create <fir::CallOp>(loc, func, args);
272+ auto convOp = builder.createConvert (loc, op.getResult ().getType (),
273+ callOp.getResult (0 ));
274+ rewriter.replaceOp (op, convOp);
275+ return mlir::success ();
276+ }
277+
278+ // Convert descriptor allocations to function call.
279+ auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType ());
215280 mlir::func::FuncOp func =
216281 fir::runtime::getRuntimeFunc<mkRTKey (CUFAllocDesciptor)>(loc, builder);
217-
218282 auto fTy = func.getFunctionType ();
219- mlir::Value sourceFile = fir::factory::locationToFilename (builder, loc);
220283 mlir::Value sourceLine =
221284 fir::factory::locationToLineNo (builder, loc, fTy .getInput (2 ));
222285
@@ -245,26 +308,39 @@ struct CufFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> {
245308 mlir::LogicalResult
246309 matchAndRewrite (cuf::FreeOp op,
247310 mlir::PatternRewriter &rewriter) const override {
248- // Only convert cuf.free on descriptor.
249- if (!mlir::isa<fir::ReferenceType>(op.getDevptr ().getType ()))
250- return failure ();
251- auto refTy = mlir::dyn_cast<fir::ReferenceType>(op.getDevptr ().getType ());
252- if (!mlir::isa<fir::BaseBoxType>(refTy.getEleTy ()))
253- return failure ();
254-
255311 if (inDeviceContext (op.getOperation ())) {
256312 rewriter.eraseOp (op);
257313 return mlir::success ();
258314 }
259315
316+ if (!mlir::isa<fir::ReferenceType>(op.getDevptr ().getType ()))
317+ return failure ();
318+
260319 auto mod = op->getParentOfType <mlir::ModuleOp>();
261320 fir::FirOpBuilder builder (rewriter, mod);
262321 mlir::Location loc = op.getLoc ();
322+ mlir::Value sourceFile = fir::factory::locationToFilename (builder, loc);
323+
324+ auto refTy = mlir::dyn_cast<fir::ReferenceType>(op.getDevptr ().getType ());
325+ if (!mlir::isa<fir::BaseBoxType>(refTy.getEleTy ())) {
326+ mlir::func::FuncOp func =
327+ fir::runtime::getRuntimeFunc<mkRTKey (CUFMemFree)>(loc, builder);
328+ auto fTy = func.getFunctionType ();
329+ mlir::Value sourceLine =
330+ fir::factory::locationToLineNo (builder, loc, fTy .getInput (3 ));
331+ mlir::Value memTy = builder.createIntegerConstant (
332+ loc, builder.getI32Type (), getMemType (op.getDataAttr ()));
333+ llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments (
334+ builder, loc, fTy , op.getDevptr (), memTy, sourceFile, sourceLine)};
335+ builder.create <fir::CallOp>(loc, func, args);
336+ rewriter.eraseOp (op);
337+ return mlir::success ();
338+ }
339+
340+ // Convert cuf.free on descriptors.
263341 mlir::func::FuncOp func =
264342 fir::runtime::getRuntimeFunc<mkRTKey (CUFFreeDesciptor)>(loc, builder);
265-
266343 auto fTy = func.getFunctionType ();
267- mlir::Value sourceFile = fir::factory::locationToFilename (builder, loc);
268344 mlir::Value sourceLine =
269345 fir::factory::locationToLineNo (builder, loc, fTy .getInput (2 ));
270346 llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments (
@@ -275,29 +351,6 @@ struct CufFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> {
275351 }
276352};
277353
278- static int computeWidth (mlir::Location loc, mlir::Type type,
279- fir::KindMapping &kindMap) {
280- auto eleTy = fir::unwrapSequenceType (type);
281- int width = 0 ;
282- if (auto t{mlir::dyn_cast<mlir::IntegerType>(eleTy)}) {
283- width = t.getWidth () / 8 ;
284- } else if (auto t{mlir::dyn_cast<mlir::FloatType>(eleTy)}) {
285- width = t.getWidth () / 8 ;
286- } else if (eleTy.isInteger (1 )) {
287- width = 1 ;
288- } else if (auto t{mlir::dyn_cast<fir::LogicalType>(eleTy)}) {
289- int kind = t.getFKind ();
290- width = kindMap.getLogicalBitsize (kind) / 8 ;
291- } else if (auto t{mlir::dyn_cast<fir::ComplexType>(eleTy)}) {
292- int kind = t.getFKind ();
293- int elemSize = kindMap.getRealBitsize (kind) / 8 ;
294- width = 2 * elemSize;
295- } else {
296- llvm::report_fatal_error (" unsupported type" );
297- }
298- return width;
299- }
300-
301354static mlir::Value createConvertOp (mlir::PatternRewriter &rewriter,
302355 mlir::Location loc, mlir::Type toTy,
303356 mlir::Value val) {
@@ -456,16 +509,6 @@ class CufOpConversion : public fir::impl::CufOpConversionBase<CufOpConversion> {
456509 fir::support::getOrSetDataLayout (module , /* allowDefaultLayout=*/ false );
457510 fir::LLVMTypeConverter typeConverter (module , /* applyTBAA=*/ false ,
458511 /* forceUnifiedTBAATree=*/ false , *dl);
459- target.addDynamicallyLegalOp <cuf::AllocOp>([](::cuf::AllocOp op) {
460- return !mlir::isa<fir::BaseBoxType>(op.getInType ());
461- });
462- target.addDynamicallyLegalOp <cuf::FreeOp>([](::cuf::FreeOp op) {
463- if (auto refTy = mlir::dyn_cast_or_null<fir::ReferenceType>(
464- op.getDevptr ().getType ())) {
465- return !mlir::isa<fir::BaseBoxType>(refTy.getEleTy ());
466- }
467- return true ;
468- });
469512 target.addDynamicallyLegalOp <cuf::DataTransferOp>(
470513 [](::cuf::DataTransferOp op) {
471514 mlir::Type srcTy = fir::unwrapRefType (op.getSrc ().getType ());
0 commit comments