@@ -234,6 +234,60 @@ class SaveResultOpConversion
234234 }
235235};
236236
237+ template <typename OpTy>
238+ static mlir::LogicalResult
239+ processReturnLikeOp (OpTy ret, mlir::Value newArg,
240+ mlir::PatternRewriter &rewriter) {
241+ auto loc = ret.getLoc ();
242+ rewriter.setInsertionPoint (ret);
243+ mlir::Value resultValue = ret.getOperand (0 );
244+ fir::LoadOp resultLoad;
245+ mlir::Value resultStorage;
246+ // Identify result local storage.
247+ if (auto load = resultValue.getDefiningOp <fir::LoadOp>()) {
248+ resultLoad = load;
249+ resultStorage = load.getMemref ();
250+ // The result alloca may be behind a fir.declare, if any.
251+ if (auto declare = resultStorage.getDefiningOp <fir::DeclareOp>())
252+ resultStorage = declare.getMemref ();
253+ }
254+ // Replace old local storage with new storage argument, unless
255+ // the derived type is C_PTR/C_FUN_PTR, in which case the return
256+ // type is updated to return void* (no new argument is passed).
257+ if (fir::isa_builtin_cptr_type (resultValue.getType ())) {
258+ auto module = ret->template getParentOfType <mlir::ModuleOp>();
259+ FirOpBuilder builder (rewriter, module );
260+ mlir::Value cptr = resultValue;
261+ if (resultLoad) {
262+ // Replace whole derived type load by component load.
263+ cptr = resultLoad.getMemref ();
264+ rewriter.setInsertionPoint (resultLoad);
265+ }
266+ mlir::Value newResultValue =
267+ fir::factory::genCPtrOrCFunptrValue (builder, loc, cptr);
268+ newResultValue = builder.createConvert (
269+ loc, getVoidPtrType (ret.getContext ()), newResultValue);
270+ rewriter.setInsertionPoint (ret);
271+ rewriter.replaceOpWithNewOp <OpTy>(ret, mlir::ValueRange{newResultValue});
272+ } else if (resultStorage) {
273+ resultStorage.replaceAllUsesWith (newArg);
274+ rewriter.replaceOpWithNewOp <OpTy>(ret);
275+ } else {
276+ // The result storage may have been optimized out by a memory to
277+ // register pass, this is possible for fir.box results, or fir.record
278+ // with no length parameters. Simply store the result in the result
279+ // storage. at the return point.
280+ rewriter.create <fir::StoreOp>(loc, resultValue, newArg);
281+ rewriter.replaceOpWithNewOp <OpTy>(ret);
282+ }
283+ // Delete result old local storage if unused.
284+ if (resultStorage)
285+ if (auto alloc = resultStorage.getDefiningOp <fir::AllocaOp>())
286+ if (alloc->use_empty ())
287+ rewriter.eraseOp (alloc);
288+ return mlir::success ();
289+ }
290+
237291class ReturnOpConversion : public mlir ::OpRewritePattern<mlir::func::ReturnOp> {
238292public:
239293 using OpRewritePattern::OpRewritePattern;
@@ -242,55 +296,23 @@ class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> {
242296 llvm::LogicalResult
243297 matchAndRewrite (mlir::func::ReturnOp ret,
244298 mlir::PatternRewriter &rewriter) const override {
245- auto loc = ret.getLoc ();
246- rewriter.setInsertionPoint (ret);
247- mlir::Value resultValue = ret.getOperand (0 );
248- fir::LoadOp resultLoad;
249- mlir::Value resultStorage;
250- // Identify result local storage.
251- if (auto load = resultValue.getDefiningOp <fir::LoadOp>()) {
252- resultLoad = load;
253- resultStorage = load.getMemref ();
254- // The result alloca may be behind a fir.declare, if any.
255- if (auto declare = resultStorage.getDefiningOp <fir::DeclareOp>())
256- resultStorage = declare.getMemref ();
257- }
258- // Replace old local storage with new storage argument, unless
259- // the derived type is C_PTR/C_FUN_PTR, in which case the return
260- // type is updated to return void* (no new argument is passed).
261- if (fir::isa_builtin_cptr_type (resultValue.getType ())) {
262- auto module = ret->getParentOfType <mlir::ModuleOp>();
263- FirOpBuilder builder (rewriter, module );
264- mlir::Value cptr = resultValue;
265- if (resultLoad) {
266- // Replace whole derived type load by component load.
267- cptr = resultLoad.getMemref ();
268- rewriter.setInsertionPoint (resultLoad);
269- }
270- mlir::Value newResultValue =
271- fir::factory::genCPtrOrCFunptrValue (builder, loc, cptr);
272- newResultValue = builder.createConvert (
273- loc, getVoidPtrType (ret.getContext ()), newResultValue);
274- rewriter.setInsertionPoint (ret);
275- rewriter.replaceOpWithNewOp <mlir::func::ReturnOp>(
276- ret, mlir::ValueRange{newResultValue});
277- } else if (resultStorage) {
278- resultStorage.replaceAllUsesWith (newArg);
279- rewriter.replaceOpWithNewOp <mlir::func::ReturnOp>(ret);
280- } else {
281- // The result storage may have been optimized out by a memory to
282- // register pass, this is possible for fir.box results, or fir.record
283- // with no length parameters. Simply store the result in the result
284- // storage. at the return point.
285- rewriter.create <fir::StoreOp>(loc, resultValue, newArg);
286- rewriter.replaceOpWithNewOp <mlir::func::ReturnOp>(ret);
287- }
288- // Delete result old local storage if unused.
289- if (resultStorage)
290- if (auto alloc = resultStorage.getDefiningOp <fir::AllocaOp>())
291- if (alloc->use_empty ())
292- rewriter.eraseOp (alloc);
293- return mlir::success ();
299+ return processReturnLikeOp (ret, newArg, rewriter);
300+ }
301+
302+ private:
303+ mlir::Value newArg;
304+ };
305+
306+ class GPUReturnOpConversion
307+ : public mlir::OpRewritePattern<mlir::gpu::ReturnOp> {
308+ public:
309+ using OpRewritePattern::OpRewritePattern;
310+ GPUReturnOpConversion (mlir::MLIRContext *context, mlir::Value newArg)
311+ : OpRewritePattern(context), newArg{newArg} {}
312+ llvm::LogicalResult
313+ matchAndRewrite (mlir::gpu::ReturnOp ret,
314+ mlir::PatternRewriter &rewriter) const override {
315+ return processReturnLikeOp (ret, newArg, rewriter);
294316 }
295317
296318private:
@@ -373,6 +395,9 @@ class AbstractResultOpt
373395 patterns.insert <ReturnOpConversion>(context, newArg);
374396 target.addDynamicallyLegalOp <mlir::func::ReturnOp>(
375397 [](mlir::func::ReturnOp ret) { return ret.getOperands ().empty (); });
398+ patterns.insert <GPUReturnOpConversion>(context, newArg);
399+ target.addDynamicallyLegalOp <mlir::gpu::ReturnOp>(
400+ [](mlir::gpu::ReturnOp ret) { return ret.getOperands ().empty (); });
376401 assert (func.getFunctionType () ==
377402 getNewFunctionType (funcTy, shouldBoxResult));
378403 } else {
0 commit comments