@@ -726,14 +726,14 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
726726 if (targetFeaturesAttr)
727727 fn->setAttr (" target_features" , targetFeaturesAttr);
728728
729- convertSignature (fn);
729+ convertSignature<mlir::func::ReturnOp> (fn);
730730 }
731731
732732 for (auto gpuMod : mod.getOps <mlir::gpu::GPUModuleOp>()) {
733733 for (auto fn : gpuMod.getOps <mlir::func::FuncOp>())
734- convertSignature (fn);
734+ convertSignature<mlir::func::ReturnOp> (fn);
735735 for (auto fn : gpuMod.getOps <mlir::gpu::GPUFuncOp>())
736- convertSignature (fn);
736+ convertSignature<mlir::gpu::ReturnOp> (fn);
737737 }
738738
739739 return mlir::success ();
@@ -792,8 +792,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
792792
793793 // / Rewrite the signatures and body of the `FuncOp`s in the module for
794794 // / the immediately subsequent target code gen.
795- template <typename OpTy >
796- void convertSignature (OpTy func) {
795+ template <typename ReturnOpTy, typename FuncOpTy >
796+ void convertSignature (FuncOpTy func) {
797797 auto funcTy = mlir::cast<mlir::FunctionType>(func.getFunctionType ());
798798 if (hasPortableSignature (funcTy, func) && !hasHostAssociations (func))
799799 return ;
@@ -900,7 +900,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
900900 if (!extensionAttrName.empty () &&
901901 isFuncWithCCallingConvention (func))
902902 fixups.emplace_back (FixupTy::Codes::ArgumentType, argNo,
903- [=](OpTy func) {
903+ [=](FuncOpTy func) {
904904 func.setArgAttr (
905905 argNo, extensionAttrName,
906906 mlir::UnitAttr::get (func.getContext ()));
@@ -992,52 +992,29 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
992992 auto newArg =
993993 func.front ().insertArgument (fixup.index , fixupType, loc);
994994 offset++;
995- if constexpr (std::is_same_v<OpTy, mlir::func::FuncOp>)
996- func.walk ([&](mlir::func::ReturnOp ret) {
997- rewriter->setInsertionPoint (ret);
998- auto oldOper = ret.getOperand (0 );
999- auto oldOperTy = fir::ReferenceType::get (oldOper.getType ());
1000- auto cast =
1001- rewriter->create <fir::ConvertOp>(loc, oldOperTy, newArg);
1002- rewriter->create <fir::StoreOp>(loc, oldOper, cast);
1003- rewriter->create <mlir::func::ReturnOp>(loc);
1004- ret.erase ();
1005- });
1006- if constexpr (std::is_same_v<OpTy, mlir::gpu::GPUFuncOp>)
1007- func.walk ([&](mlir::gpu::ReturnOp ret) {
1008- rewriter->setInsertionPoint (ret);
1009- auto oldOper = ret.getOperand (0 );
1010- auto oldOperTy = fir::ReferenceType::get (oldOper.getType ());
1011- auto cast =
1012- rewriter->create <fir::ConvertOp>(loc, oldOperTy, newArg);
1013- rewriter->create <fir::StoreOp>(loc, oldOper, cast);
1014- rewriter->create <mlir::gpu::ReturnOp>(loc);
1015- ret.erase ();
1016- });
995+ func.walk ([&](ReturnOpTy ret) {
996+ rewriter->setInsertionPoint (ret);
997+ auto oldOper = ret.getOperand (0 );
998+ auto oldOperTy = fir::ReferenceType::get (oldOper.getType ());
999+ auto cast =
1000+ rewriter->create <fir::ConvertOp>(loc, oldOperTy, newArg);
1001+ rewriter->create <fir::StoreOp>(loc, oldOper, cast);
1002+ rewriter->create <ReturnOpTy>(loc);
1003+ ret.erase ();
1004+ });
10171005 } break ;
10181006 case FixupTy::Codes::ReturnType: {
10191007 // The function is still returning a value, but its type has likely
10201008 // changed to suit the target ABI convention.
1021- if constexpr (std::is_same_v<OpTy, mlir::func::FuncOp>)
1022- func.walk ([&](mlir::func::ReturnOp ret) {
1023- rewriter->setInsertionPoint (ret);
1024- auto oldOper = ret.getOperand (0 );
1025- mlir::Value bitcast =
1026- convertValueInMemory (loc, oldOper, newResTys[fixup.index ],
1027- /* inputMayBeBigger=*/ false );
1028- rewriter->create <mlir::func::ReturnOp>(loc, bitcast);
1029- ret.erase ();
1030- });
1031- if constexpr (std::is_same_v<OpTy, mlir::gpu::GPUFuncOp>)
1032- func.walk ([&](mlir::gpu::ReturnOp ret) {
1033- rewriter->setInsertionPoint (ret);
1034- auto oldOper = ret.getOperand (0 );
1035- mlir::Value bitcast =
1036- convertValueInMemory (loc, oldOper, newResTys[fixup.index ],
1037- /* inputMayBeBigger=*/ false );
1038- rewriter->create <mlir::gpu::ReturnOp>(loc, bitcast);
1039- ret.erase ();
1040- });
1009+ func.walk ([&](ReturnOpTy ret) {
1010+ rewriter->setInsertionPoint (ret);
1011+ auto oldOper = ret.getOperand (0 );
1012+ mlir::Value bitcast =
1013+ convertValueInMemory (loc, oldOper, newResTys[fixup.index ],
1014+ /* inputMayBeBigger=*/ false );
1015+ rewriter->create <ReturnOpTy>(loc, bitcast);
1016+ ret.erase ();
1017+ });
10411018 } break ;
10421019 case FixupTy::Codes::Split: {
10431020 // The FIR argument has been split into a pair of distinct arguments
@@ -1138,10 +1115,10 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
11381115 }
11391116
11401117 for (auto &fixup : fixups) {
1141- if constexpr (std::is_same_v<OpTy , mlir::func::FuncOp>)
1118+ if constexpr (std::is_same_v<FuncOpTy , mlir::func::FuncOp>)
11421119 if (fixup.finalizer )
11431120 (*fixup.finalizer )(func);
1144- if constexpr (std::is_same_v<OpTy , mlir::gpu::GPUFuncOp>)
1121+ if constexpr (std::is_same_v<FuncOpTy , mlir::gpu::GPUFuncOp>)
11451122 if (fixup.gpuFinalizer )
11461123 (*fixup.gpuFinalizer )(func);
11471124 }
0 commit comments