@@ -62,14 +62,21 @@ struct FixupTy {
6262 FixupTy (Codes code, std::size_t index,
6363 std::function<void (mlir::func::FuncOp)> &&finalizer)
6464 : code{code}, index{index}, finalizer{finalizer} {}
65+ FixupTy (Codes code, std::size_t index,
66+ std::function<void (mlir::gpu::GPUFuncOp)> &&finalizer)
67+ : code{code}, index{index}, gpuFinalizer{finalizer} {}
6568 FixupTy (Codes code, std::size_t index, std::size_t second,
6669 std::function<void (mlir::func::FuncOp)> &&finalizer)
6770 : code{code}, index{index}, second{second}, finalizer{finalizer} {}
71+ FixupTy (Codes code, std::size_t index, std::size_t second,
72+ std::function<void (mlir::gpu::GPUFuncOp)> &&finalizer)
73+ : code{code}, index{index}, second{second}, gpuFinalizer{finalizer} {}
6874
6975 Codes code;
7076 std::size_t index;
7177 std::size_t second{};
7278 std::optional<std::function<void (mlir::func::FuncOp)>> finalizer{};
79+ std::optional<std::function<void (mlir::gpu::GPUFuncOp)>> gpuFinalizer{};
7380}; // namespace
7481
7582// / Target-specific rewriting of the FIR. This is a prerequisite pass to code
@@ -719,12 +726,15 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
719726 if (targetFeaturesAttr)
720727 fn->setAttr (" target_features" , targetFeaturesAttr);
721728
722- convertSignature (fn);
729+ convertSignature<mlir::func::ReturnOp, mlir::func::FuncOp> (fn);
723730 }
724731
725- for (auto gpuMod : mod.getOps <mlir::gpu::GPUModuleOp>())
732+ for (auto gpuMod : mod.getOps <mlir::gpu::GPUModuleOp>()) {
726733 for (auto fn : gpuMod.getOps <mlir::func::FuncOp>())
727- convertSignature (fn);
734+ convertSignature<mlir::func::ReturnOp, mlir::func::FuncOp>(fn);
735+ for (auto fn : gpuMod.getOps <mlir::gpu::GPUFuncOp>())
736+ convertSignature<mlir::gpu::ReturnOp, mlir::gpu::GPUFuncOp>(fn);
737+ }
728738
729739 return mlir::success ();
730740 }
@@ -770,17 +780,20 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
770780
771781 // / Determine if the signature has host associations. The host association
772782 // / argument may need special target specific rewriting.
773- static bool hasHostAssociations (mlir::func::FuncOp func) {
783+ template <typename OpTy>
784+ static bool hasHostAssociations (OpTy func) {
774785 std::size_t end = func.getFunctionType ().getInputs ().size ();
775786 for (std::size_t i = 0 ; i < end; ++i)
776- if (func.getArgAttrOfType <mlir::UnitAttr>(i, fir::getHostAssocAttrName ()))
787+ if (func.template getArgAttrOfType <mlir::UnitAttr>(
788+ i, fir::getHostAssocAttrName ()))
777789 return true ;
778790 return false ;
779791 }
780792
781793 // / Rewrite the signatures and body of the `FuncOp`s in the module for
782794 // / the immediately subsequent target code gen.
783- void convertSignature (mlir::func::FuncOp func) {
795+ template <typename ReturnOpTy, typename FuncOpTy>
796+ void convertSignature (FuncOpTy func) {
784797 auto funcTy = mlir::cast<mlir::FunctionType>(func.getFunctionType ());
785798 if (hasPortableSignature (funcTy, func) && !hasHostAssociations (func))
786799 return ;
@@ -805,13 +818,13 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
805818 // Convert return value(s)
806819 for (auto ty : funcTy.getResults ())
807820 llvm::TypeSwitch<mlir::Type>(ty)
808- .Case <mlir::ComplexType>([&](mlir::ComplexType cmplx) {
821+ .template Case <mlir::ComplexType>([&](mlir::ComplexType cmplx) {
809822 if (noComplexConversion)
810823 newResTys.push_back (cmplx);
811824 else
812825 doComplexReturn (func, cmplx, newResTys, newInTyAndAttrs, fixups);
813826 })
814- .Case <mlir::IntegerType>([&](mlir::IntegerType intTy) {
827+ .template Case <mlir::IntegerType>([&](mlir::IntegerType intTy) {
815828 auto m = specifics->integerArgumentType (func.getLoc (), intTy);
816829 assert (m.size () == 1 );
817830 auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0 ]);
@@ -825,7 +838,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
825838 rewriter->getUnitAttr ()));
826839 newResTys.push_back (retTy);
827840 })
828- .Case <fir::RecordType>([&](fir::RecordType recTy) {
841+ .template Case <fir::RecordType>([&](fir::RecordType recTy) {
829842 doStructReturn (func, recTy, newResTys, newInTyAndAttrs, fixups);
830843 })
831844 .Default ([&](mlir::Type ty) { newResTys.push_back (ty); });
@@ -840,7 +853,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
840853 auto ty = e.value ();
841854 unsigned index = e.index ();
842855 llvm::TypeSwitch<mlir::Type>(ty)
843- .Case <fir::BoxCharType>([&](fir::BoxCharType boxTy) {
856+ .template Case <fir::BoxCharType>([&](fir::BoxCharType boxTy) {
844857 if (noCharacterConversion) {
845858 newInTyAndAttrs.push_back (
846859 fir::CodeGenSpecifics::getTypeAndAttr (boxTy));
@@ -863,10 +876,10 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
863876 }
864877 }
865878 })
866- .Case <mlir::ComplexType>([&](mlir::ComplexType cmplx) {
879+ .template Case <mlir::ComplexType>([&](mlir::ComplexType cmplx) {
867880 doComplexArg (func, cmplx, newInTyAndAttrs, fixups);
868881 })
869- .Case <mlir::TupleType>([&](mlir::TupleType tuple) {
882+ .template Case <mlir::TupleType>([&](mlir::TupleType tuple) {
870883 if (fir::isCharacterProcedureTuple (tuple)) {
871884 fixups.emplace_back (FixupTy::Codes::TrailingCharProc,
872885 newInTyAndAttrs.size (), trailingTys.size ());
@@ -878,7 +891,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
878891 fir::CodeGenSpecifics::getTypeAndAttr (ty));
879892 }
880893 })
881- .Case <mlir::IntegerType>([&](mlir::IntegerType intTy) {
894+ .template Case <mlir::IntegerType>([&](mlir::IntegerType intTy) {
882895 auto m = specifics->integerArgumentType (func.getLoc (), intTy);
883896 assert (m.size () == 1 );
884897 auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0 ]);
@@ -887,7 +900,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
887900 if (!extensionAttrName.empty () &&
888901 isFuncWithCCallingConvention (func))
889902 fixups.emplace_back (FixupTy::Codes::ArgumentType, argNo,
890- [=](mlir::func::FuncOp func) {
903+ [=](FuncOpTy func) {
891904 func.setArgAttr (
892905 argNo, extensionAttrName,
893906 mlir::UnitAttr::get (func.getContext ()));
@@ -903,8 +916,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
903916 fir::CodeGenSpecifics::getTypeAndAttr (ty));
904917 });
905918
906- if (func.getArgAttrOfType <mlir::UnitAttr>(index,
907- fir::getHostAssocAttrName ())) {
919+ if (func.template getArgAttrOfType <mlir::UnitAttr>(
920+ index, fir::getHostAssocAttrName ())) {
908921 extraAttrs.push_back (
909922 {newInTyAndAttrs.size () - 1 ,
910923 rewriter->getNamedAttr (" llvm.nest" , rewriter->getUnitAttr ())});
@@ -979,27 +992,27 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
979992 auto newArg =
980993 func.front ().insertArgument (fixup.index , fixupType, loc);
981994 offset++;
982- func.walk ([&](mlir::func::ReturnOp ret) {
995+ func.walk ([&](ReturnOpTy ret) {
983996 rewriter->setInsertionPoint (ret);
984997 auto oldOper = ret.getOperand (0 );
985998 auto oldOperTy = fir::ReferenceType::get (oldOper.getType ());
986999 auto cast =
9871000 rewriter->create <fir::ConvertOp>(loc, oldOperTy, newArg);
9881001 rewriter->create <fir::StoreOp>(loc, oldOper, cast);
989- rewriter->create <mlir::func::ReturnOp >(loc);
1002+ rewriter->create <ReturnOpTy >(loc);
9901003 ret.erase ();
9911004 });
9921005 } break ;
9931006 case FixupTy::Codes::ReturnType: {
9941007 // The function is still returning a value, but its type has likely
9951008 // changed to suit the target ABI convention.
996- func.walk ([&](mlir::func::ReturnOp ret) {
1009+ func.walk ([&](ReturnOpTy ret) {
9971010 rewriter->setInsertionPoint (ret);
9981011 auto oldOper = ret.getOperand (0 );
9991012 mlir::Value bitcast =
10001013 convertValueInMemory (loc, oldOper, newResTys[fixup.index ],
10011014 /* inputMayBeBigger=*/ false );
1002- rewriter->create <mlir::func::ReturnOp >(loc, bitcast);
1015+ rewriter->create <ReturnOpTy >(loc, bitcast);
10031016 ret.erase ();
10041017 });
10051018 } break ;
@@ -1101,13 +1114,18 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
11011114 }
11021115 }
11031116
1104- for (auto &fixup : fixups)
1105- if (fixup.finalizer )
1106- (*fixup.finalizer )(func);
1117+ for (auto &fixup : fixups) {
1118+ if constexpr (std::is_same_v<FuncOpTy, mlir::func::FuncOp>)
1119+ if (fixup.finalizer )
1120+ (*fixup.finalizer )(func);
1121+ if constexpr (std::is_same_v<FuncOpTy, mlir::gpu::GPUFuncOp>)
1122+ if (fixup.gpuFinalizer )
1123+ (*fixup.gpuFinalizer )(func);
1124+ }
11071125 }
11081126
1109- template <typename Ty, typename FIXUPS>
1110- void doReturn (mlir::func::FuncOp func, Ty &newResTys,
1127+ template <typename OpTy, typename Ty, typename FIXUPS>
1128+ void doReturn (OpTy func, Ty &newResTys,
11111129 fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
11121130 FIXUPS &fixups, fir::CodeGenSpecifics::Marshalling &m) {
11131131 assert (m.size () == 1 &&
@@ -1119,7 +1137,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
11191137 unsigned argNo = newInTyAndAttrs.size ();
11201138 if (auto align = attr.getAlignment ())
11211139 fixups.emplace_back (
1122- FixupTy::Codes::ReturnAsStore, argNo, [=](mlir::func::FuncOp func) {
1140+ FixupTy::Codes::ReturnAsStore, argNo, [=](OpTy func) {
11231141 auto elemType = fir::dyn_cast_ptrOrBoxEleTy (
11241142 func.getFunctionType ().getInput (argNo));
11251143 func.setArgAttr (argNo, " llvm.sret" ,
@@ -1130,7 +1148,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
11301148 });
11311149 else
11321150 fixups.emplace_back (FixupTy::Codes::ReturnAsStore, argNo,
1133- [=](mlir::func::FuncOp func) {
1151+ [=](OpTy func) {
11341152 auto elemType = fir::dyn_cast_ptrOrBoxEleTy (
11351153 func.getFunctionType ().getInput (argNo));
11361154 func.setArgAttr (argNo, " llvm.sret" ,
@@ -1141,8 +1159,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
11411159 }
11421160 if (auto align = attr.getAlignment ())
11431161 fixups.emplace_back (
1144- FixupTy::Codes::ReturnType, newResTys.size (),
1145- [=](mlir::func::FuncOp func) {
1162+ FixupTy::Codes::ReturnType, newResTys.size (), [=](OpTy func) {
11461163 func.setArgAttr (
11471164 newResTys.size (), " llvm.align" ,
11481165 rewriter->getIntegerAttr (rewriter->getIntegerType (32 ), align));
@@ -1155,9 +1172,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
11551172 // / Convert a complex return value. This can involve converting the return
11561173 // / value to a "hidden" first argument or packing the complex into a wide
11571174 // / GPR.
1158- template <typename Ty, typename FIXUPS>
1159- void doComplexReturn (mlir::func::FuncOp func, mlir::ComplexType cmplx,
1160- Ty &newResTys,
1175+ template <typename OpTy, typename Ty, typename FIXUPS>
1176+ void doComplexReturn (OpTy func, mlir::ComplexType cmplx, Ty &newResTys,
11611177 fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
11621178 FIXUPS &fixups) {
11631179 if (noComplexConversion) {
@@ -1169,9 +1185,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
11691185 doReturn (func, newResTys, newInTyAndAttrs, fixups, m);
11701186 }
11711187
1172- template <typename Ty, typename FIXUPS>
1173- void doStructReturn (mlir::func::FuncOp func, fir::RecordType recTy,
1174- Ty &newResTys,
1188+ template <typename OpTy, typename Ty, typename FIXUPS>
1189+ void doStructReturn (OpTy func, fir::RecordType recTy, Ty &newResTys,
11751190 fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
11761191 FIXUPS &fixups) {
11771192 if (noStructConversion) {
@@ -1182,12 +1197,10 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
11821197 doReturn (func, newResTys, newInTyAndAttrs, fixups, m);
11831198 }
11841199
1185- template <typename FIXUPS>
1186- void
1187- createFuncOpArgFixups (mlir::func::FuncOp func,
1188- fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
1189- fir::CodeGenSpecifics::Marshalling &argsInTys,
1190- FIXUPS &fixups) {
1200+ template <typename OpTy, typename FIXUPS>
1201+ void createFuncOpArgFixups (
1202+ OpTy func, fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
1203+ fir::CodeGenSpecifics::Marshalling &argsInTys, FIXUPS &fixups) {
11911204 const auto fixupCode = argsInTys.size () > 1 ? FixupTy::Codes::Split
11921205 : FixupTy::Codes::ArgumentType;
11931206 for (auto e : llvm::enumerate (argsInTys)) {
@@ -1198,7 +1211,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
11981211 if (attr.isByVal ()) {
11991212 if (auto align = attr.getAlignment ())
12001213 fixups.emplace_back (FixupTy::Codes::ArgumentAsLoad, argNo,
1201- [=](mlir::func::FuncOp func) {
1214+ [=](OpTy func) {
12021215 auto elemType = fir::dyn_cast_ptrOrBoxEleTy (
12031216 func.getFunctionType ().getInput (argNo));
12041217 func.setArgAttr (argNo, " llvm.byval" ,
@@ -1210,8 +1223,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
12101223 });
12111224 else
12121225 fixups.emplace_back (FixupTy::Codes::ArgumentAsLoad,
1213- newInTyAndAttrs.size (),
1214- [=](mlir::func::FuncOp func) {
1226+ newInTyAndAttrs.size (), [=](OpTy func) {
12151227 auto elemType = fir::dyn_cast_ptrOrBoxEleTy (
12161228 func.getFunctionType ().getInput (argNo));
12171229 func.setArgAttr (argNo, " llvm.byval" ,
@@ -1220,7 +1232,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
12201232 } else {
12211233 if (auto align = attr.getAlignment ())
12221234 fixups.emplace_back (
1223- fixupCode, argNo, index, [=](mlir::func::FuncOp func) {
1235+ fixupCode, argNo, index, [=](OpTy func) {
12241236 func.setArgAttr (argNo, " llvm.align" ,
12251237 rewriter->getIntegerAttr (
12261238 rewriter->getIntegerType (32 ), align));
@@ -1235,8 +1247,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
12351247 // / Convert a complex argument value. This can involve storing the value to
12361248 // / a temporary memory location or factoring the value into two distinct
12371249 // / arguments.
1238- template <typename FIXUPS>
1239- void doComplexArg (mlir::func::FuncOp func, mlir::ComplexType cmplx,
1250+ template <typename OpTy, typename FIXUPS>
1251+ void doComplexArg (OpTy func, mlir::ComplexType cmplx,
12401252 fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
12411253 FIXUPS &fixups) {
12421254 if (noComplexConversion) {
@@ -1248,8 +1260,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
12481260 createFuncOpArgFixups (func, newInTyAndAttrs, cplxArgs, fixups);
12491261 }
12501262
1251- template <typename FIXUPS>
1252- void doStructArg (mlir::func::FuncOp func, fir::RecordType recTy,
1263+ template <typename OpTy, typename FIXUPS>
1264+ void doStructArg (OpTy func, fir::RecordType recTy,
12531265 fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
12541266 FIXUPS &fixups) {
12551267 if (noStructConversion) {
0 commit comments