@@ -62,14 +62,21 @@ struct FixupTy {
6262 FixupTy (Codes code, std::size_t index,
6363 std::function<void (mlir::func::FuncOp)> &&finalizer)
6464 : code{code}, index{index}, finalizer{finalizer} {}
65+ FixupTy (Codes code, std::size_t index,
66+ std::function<void (mlir::gpu::GPUFuncOp)> &&finalizer)
67+ : code{code}, index{index}, gpuFinalizer{finalizer} {}
6568 FixupTy (Codes code, std::size_t index, std::size_t second,
6669 std::function<void (mlir::func::FuncOp)> &&finalizer)
6770 : code{code}, index{index}, second{second}, finalizer{finalizer} {}
71+ FixupTy (Codes code, std::size_t index, std::size_t second,
72+ std::function<void (mlir::gpu::GPUFuncOp)> &&finalizer)
73+ : code{code}, index{index}, second{second}, gpuFinalizer{finalizer} {}
6874
6975 Codes code;
7076 std::size_t index;
7177 std::size_t second{};
7278 std::optional<std::function<void (mlir::func::FuncOp)>> finalizer{};
79+ std::optional<std::function<void (mlir::gpu::GPUFuncOp)>> gpuFinalizer{};
7380}; // namespace
7481
7582// / Target-specific rewriting of the FIR. This is a prerequisite pass to code
@@ -722,9 +729,12 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
722729 convertSignature (fn);
723730 }
724731
725- for (auto gpuMod : mod.getOps <mlir::gpu::GPUModuleOp>())
732+ for (auto gpuMod : mod.getOps <mlir::gpu::GPUModuleOp>()) {
726733 for (auto fn : gpuMod.getOps <mlir::func::FuncOp>())
727734 convertSignature (fn);
735+ for (auto fn : gpuMod.getOps <mlir::gpu::GPUFuncOp>())
736+ convertSignature (fn);
737+ }
728738
729739 return mlir::success ();
730740 }
@@ -770,17 +780,20 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
770780
771781 // / Determine if the signature has host associations. The host association
772782 // / argument may need special target specific rewriting.
773- static bool hasHostAssociations (mlir::func::FuncOp func) {
783+ template <typename OpTy>
784+ static bool hasHostAssociations (OpTy func) {
774785 std::size_t end = func.getFunctionType ().getInputs ().size ();
775786 for (std::size_t i = 0 ; i < end; ++i)
776- if (func.getArgAttrOfType <mlir::UnitAttr>(i, fir::getHostAssocAttrName ()))
787+ if (func.template getArgAttrOfType <mlir::UnitAttr>(
788+ i, fir::getHostAssocAttrName ()))
777789 return true ;
778790 return false ;
779791 }
780792
781793 // / Rewrite the signatures and body of the `FuncOp`s in the module for
782794 // / the immediately subsequent target code gen.
783- void convertSignature (mlir::func::FuncOp func) {
795+ template <typename OpTy>
796+ void convertSignature (OpTy func) {
784797 auto funcTy = mlir::cast<mlir::FunctionType>(func.getFunctionType ());
785798 if (hasPortableSignature (funcTy, func) && !hasHostAssociations (func))
786799 return ;
@@ -805,13 +818,13 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
805818 // Convert return value(s)
806819 for (auto ty : funcTy.getResults ())
807820 llvm::TypeSwitch<mlir::Type>(ty)
808- .Case <mlir::ComplexType>([&](mlir::ComplexType cmplx) {
821+ .template Case <mlir::ComplexType>([&](mlir::ComplexType cmplx) {
809822 if (noComplexConversion)
810823 newResTys.push_back (cmplx);
811824 else
812825 doComplexReturn (func, cmplx, newResTys, newInTyAndAttrs, fixups);
813826 })
814- .Case <mlir::IntegerType>([&](mlir::IntegerType intTy) {
827+ .template Case <mlir::IntegerType>([&](mlir::IntegerType intTy) {
815828 auto m = specifics->integerArgumentType (func.getLoc (), intTy);
816829 assert (m.size () == 1 );
817830 auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0 ]);
@@ -825,10 +838,10 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
825838 rewriter->getUnitAttr ()));
826839 newResTys.push_back (retTy);
827840 })
828- .Case <fir::RecordType>([&](fir::RecordType recTy) {
841+ .template Case <fir::RecordType>([&](fir::RecordType recTy) {
829842 doStructReturn (func, recTy, newResTys, newInTyAndAttrs, fixups);
830843 })
831- .Default ([&](mlir::Type ty) { newResTys.push_back (ty); });
844+ .template Default ([&](mlir::Type ty) { newResTys.push_back (ty); });
832845
833846 // Saved potential shift in argument. Handling of result can add arguments
834847 // at the beginning of the function signature.
@@ -840,7 +853,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
840853 auto ty = e.value ();
841854 unsigned index = e.index ();
842855 llvm::TypeSwitch<mlir::Type>(ty)
843- .Case <fir::BoxCharType>([&](fir::BoxCharType boxTy) {
856+ .template Case <fir::BoxCharType>([&](fir::BoxCharType boxTy) {
844857 if (noCharacterConversion) {
845858 newInTyAndAttrs.push_back (
846859 fir::CodeGenSpecifics::getTypeAndAttr (boxTy));
@@ -863,10 +876,10 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
863876 }
864877 }
865878 })
866- .Case <mlir::ComplexType>([&](mlir::ComplexType cmplx) {
879+ .template Case <mlir::ComplexType>([&](mlir::ComplexType cmplx) {
867880 doComplexArg (func, cmplx, newInTyAndAttrs, fixups);
868881 })
869- .Case <mlir::TupleType>([&](mlir::TupleType tuple) {
882+ .template Case <mlir::TupleType>([&](mlir::TupleType tuple) {
870883 if (fir::isCharacterProcedureTuple (tuple)) {
871884 fixups.emplace_back (FixupTy::Codes::TrailingCharProc,
872885 newInTyAndAttrs.size (), trailingTys.size ());
@@ -878,7 +891,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
878891 fir::CodeGenSpecifics::getTypeAndAttr (ty));
879892 }
880893 })
881- .Case <mlir::IntegerType>([&](mlir::IntegerType intTy) {
894+ .template Case <mlir::IntegerType>([&](mlir::IntegerType intTy) {
882895 auto m = specifics->integerArgumentType (func.getLoc (), intTy);
883896 assert (m.size () == 1 );
884897 auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0 ]);
@@ -887,7 +900,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
887900 if (!extensionAttrName.empty () &&
888901 isFuncWithCCallingConvention (func))
889902 fixups.emplace_back (FixupTy::Codes::ArgumentType, argNo,
890- [=](mlir::func::FuncOp func) {
903+ [=](OpTy func) {
891904 func.setArgAttr (
892905 argNo, extensionAttrName,
893906 mlir::UnitAttr::get (func.getContext ()));
@@ -898,13 +911,13 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
898911 .template Case <fir::RecordType>([&](fir::RecordType recTy) {
899912 doStructArg (func, recTy, newInTyAndAttrs, fixups);
900913 })
901- .Default ([&](mlir::Type ty) {
914+ .template Default ([&](mlir::Type ty) {
902915 newInTyAndAttrs.push_back (
903916 fir::CodeGenSpecifics::getTypeAndAttr (ty));
904917 });
905918
906- if (func.getArgAttrOfType <mlir::UnitAttr>(index,
907- fir::getHostAssocAttrName ())) {
919+ if (func.template getArgAttrOfType <mlir::UnitAttr>(
920+ index, fir::getHostAssocAttrName ())) {
908921 extraAttrs.push_back (
909922 {newInTyAndAttrs.size () - 1 ,
910923 rewriter->getNamedAttr (" llvm.nest" , rewriter->getUnitAttr ())});
@@ -979,29 +992,52 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
979992 auto newArg =
980993 func.front ().insertArgument (fixup.index , fixupType, loc);
981994 offset++;
982- func.walk ([&](mlir::func::ReturnOp ret) {
983- rewriter->setInsertionPoint (ret);
984- auto oldOper = ret.getOperand (0 );
985- auto oldOperTy = fir::ReferenceType::get (oldOper.getType ());
986- auto cast =
987- rewriter->create <fir::ConvertOp>(loc, oldOperTy, newArg);
988- rewriter->create <fir::StoreOp>(loc, oldOper, cast);
989- rewriter->create <mlir::func::ReturnOp>(loc);
990- ret.erase ();
991- });
995+ if constexpr (std::is_same_v<OpTy, mlir::func::FuncOp>)
996+ func.walk ([&](mlir::func::ReturnOp ret) {
997+ rewriter->setInsertionPoint (ret);
998+ auto oldOper = ret.getOperand (0 );
999+ auto oldOperTy = fir::ReferenceType::get (oldOper.getType ());
1000+ auto cast =
1001+ rewriter->create <fir::ConvertOp>(loc, oldOperTy, newArg);
1002+ rewriter->create <fir::StoreOp>(loc, oldOper, cast);
1003+ rewriter->create <mlir::func::ReturnOp>(loc);
1004+ ret.erase ();
1005+ });
1006+ if constexpr (std::is_same_v<OpTy, mlir::gpu::GPUFuncOp>)
1007+ func.walk ([&](mlir::gpu::ReturnOp ret) {
1008+ rewriter->setInsertionPoint (ret);
1009+ auto oldOper = ret.getOperand (0 );
1010+ auto oldOperTy = fir::ReferenceType::get (oldOper.getType ());
1011+ auto cast =
1012+ rewriter->create <fir::ConvertOp>(loc, oldOperTy, newArg);
1013+ rewriter->create <fir::StoreOp>(loc, oldOper, cast);
1014+ rewriter->create <mlir::gpu::ReturnOp>(loc);
1015+ ret.erase ();
1016+ });
9921017 } break ;
9931018 case FixupTy::Codes::ReturnType: {
9941019 // The function is still returning a value, but its type has likely
9951020 // changed to suit the target ABI convention.
996- func.walk ([&](mlir::func::ReturnOp ret) {
997- rewriter->setInsertionPoint (ret);
998- auto oldOper = ret.getOperand (0 );
999- mlir::Value bitcast =
1000- convertValueInMemory (loc, oldOper, newResTys[fixup.index ],
1001- /* inputMayBeBigger=*/ false );
1002- rewriter->create <mlir::func::ReturnOp>(loc, bitcast);
1003- ret.erase ();
1004- });
1021+ if constexpr (std::is_same_v<OpTy, mlir::func::FuncOp>)
1022+ func.walk ([&](mlir::func::ReturnOp ret) {
1023+ rewriter->setInsertionPoint (ret);
1024+ auto oldOper = ret.getOperand (0 );
1025+ mlir::Value bitcast =
1026+ convertValueInMemory (loc, oldOper, newResTys[fixup.index ],
1027+ /* inputMayBeBigger=*/ false );
1028+ rewriter->create <mlir::func::ReturnOp>(loc, bitcast);
1029+ ret.erase ();
1030+ });
1031+ if constexpr (std::is_same_v<OpTy, mlir::gpu::GPUFuncOp>)
1032+ func.walk ([&](mlir::gpu::ReturnOp ret) {
1033+ rewriter->setInsertionPoint (ret);
1034+ auto oldOper = ret.getOperand (0 );
1035+ mlir::Value bitcast =
1036+ convertValueInMemory (loc, oldOper, newResTys[fixup.index ],
1037+ /* inputMayBeBigger=*/ false );
1038+ rewriter->create <mlir::gpu::ReturnOp>(loc, bitcast);
1039+ ret.erase ();
1040+ });
10051041 } break ;
10061042 case FixupTy::Codes::Split: {
10071043 // The FIR argument has been split into a pair of distinct arguments
@@ -1101,13 +1137,18 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
11011137 }
11021138 }
11031139
1104- for (auto &fixup : fixups)
1105- if (fixup.finalizer )
1106- (*fixup.finalizer )(func);
1140+ for (auto &fixup : fixups) {
1141+ if constexpr (std::is_same_v<OpTy, mlir::func::FuncOp>)
1142+ if (fixup.finalizer )
1143+ (*fixup.finalizer )(func);
1144+ if constexpr (std::is_same_v<OpTy, mlir::gpu::GPUFuncOp>)
1145+ if (fixup.gpuFinalizer )
1146+ (*fixup.gpuFinalizer )(func);
1147+ }
11071148 }
11081149
1109- template <typename Ty, typename FIXUPS>
1110- void doReturn (mlir::func::FuncOp func, Ty &newResTys,
1150+ template <typename OpTy, typename Ty, typename FIXUPS>
1151+ void doReturn (OpTy func, Ty &newResTys,
11111152 fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
11121153 FIXUPS &fixups, fir::CodeGenSpecifics::Marshalling &m) {
11131154 assert (m.size () == 1 &&
@@ -1119,7 +1160,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
11191160 unsigned argNo = newInTyAndAttrs.size ();
11201161 if (auto align = attr.getAlignment ())
11211162 fixups.emplace_back (
1122- FixupTy::Codes::ReturnAsStore, argNo, [=](mlir::func::FuncOp func) {
1163+ FixupTy::Codes::ReturnAsStore, argNo, [=](OpTy func) {
11231164 auto elemType = fir::dyn_cast_ptrOrBoxEleTy (
11241165 func.getFunctionType ().getInput (argNo));
11251166 func.setArgAttr (argNo, " llvm.sret" ,
@@ -1130,7 +1171,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
11301171 });
11311172 else
11321173 fixups.emplace_back (FixupTy::Codes::ReturnAsStore, argNo,
1133- [=](mlir::func::FuncOp func) {
1174+ [=](OpTy func) {
11341175 auto elemType = fir::dyn_cast_ptrOrBoxEleTy (
11351176 func.getFunctionType ().getInput (argNo));
11361177 func.setArgAttr (argNo, " llvm.sret" ,
@@ -1141,8 +1182,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
11411182 }
11421183 if (auto align = attr.getAlignment ())
11431184 fixups.emplace_back (
1144- FixupTy::Codes::ReturnType, newResTys.size (),
1145- [=](mlir::func::FuncOp func) {
1185+ FixupTy::Codes::ReturnType, newResTys.size (), [=](OpTy func) {
11461186 func.setArgAttr (
11471187 newResTys.size (), " llvm.align" ,
11481188 rewriter->getIntegerAttr (rewriter->getIntegerType (32 ), align));
@@ -1155,9 +1195,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
11551195 // / Convert a complex return value. This can involve converting the return
11561196 // / value to a "hidden" first argument or packing the complex into a wide
11571197 // / GPR.
1158- template <typename Ty, typename FIXUPS>
1159- void doComplexReturn (mlir::func::FuncOp func, mlir::ComplexType cmplx,
1160- Ty &newResTys,
1198+ template <typename OpTy, typename Ty, typename FIXUPS>
1199+ void doComplexReturn (OpTy func, mlir::ComplexType cmplx, Ty &newResTys,
11611200 fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
11621201 FIXUPS &fixups) {
11631202 if (noComplexConversion) {
@@ -1169,9 +1208,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
11691208 doReturn (func, newResTys, newInTyAndAttrs, fixups, m);
11701209 }
11711210
1172- template <typename Ty, typename FIXUPS>
1173- void doStructReturn (mlir::func::FuncOp func, fir::RecordType recTy,
1174- Ty &newResTys,
1211+ template <typename OpTy, typename Ty, typename FIXUPS>
1212+ void doStructReturn (OpTy func, fir::RecordType recTy, Ty &newResTys,
11751213 fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
11761214 FIXUPS &fixups) {
11771215 if (noStructConversion) {
@@ -1182,12 +1220,10 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
11821220 doReturn (func, newResTys, newInTyAndAttrs, fixups, m);
11831221 }
11841222
1185- template <typename FIXUPS>
1186- void
1187- createFuncOpArgFixups (mlir::func::FuncOp func,
1188- fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
1189- fir::CodeGenSpecifics::Marshalling &argsInTys,
1190- FIXUPS &fixups) {
1223+ template <typename OpTy, typename FIXUPS>
1224+ void createFuncOpArgFixups (
1225+ OpTy func, fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
1226+ fir::CodeGenSpecifics::Marshalling &argsInTys, FIXUPS &fixups) {
11911227 const auto fixupCode = argsInTys.size () > 1 ? FixupTy::Codes::Split
11921228 : FixupTy::Codes::ArgumentType;
11931229 for (auto e : llvm::enumerate (argsInTys)) {
@@ -1198,7 +1234,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
11981234 if (attr.isByVal ()) {
11991235 if (auto align = attr.getAlignment ())
12001236 fixups.emplace_back (FixupTy::Codes::ArgumentAsLoad, argNo,
1201- [=](mlir::func::FuncOp func) {
1237+ [=](OpTy func) {
12021238 auto elemType = fir::dyn_cast_ptrOrBoxEleTy (
12031239 func.getFunctionType ().getInput (argNo));
12041240 func.setArgAttr (argNo, " llvm.byval" ,
@@ -1210,8 +1246,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
12101246 });
12111247 else
12121248 fixups.emplace_back (FixupTy::Codes::ArgumentAsLoad,
1213- newInTyAndAttrs.size (),
1214- [=](mlir::func::FuncOp func) {
1249+ newInTyAndAttrs.size (), [=](OpTy func) {
12151250 auto elemType = fir::dyn_cast_ptrOrBoxEleTy (
12161251 func.getFunctionType ().getInput (argNo));
12171252 func.setArgAttr (argNo, " llvm.byval" ,
@@ -1220,7 +1255,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
12201255 } else {
12211256 if (auto align = attr.getAlignment ())
12221257 fixups.emplace_back (
1223- fixupCode, argNo, index, [=](mlir::func::FuncOp func) {
1258+ fixupCode, argNo, index, [=](OpTy func) {
12241259 func.setArgAttr (argNo, " llvm.align" ,
12251260 rewriter->getIntegerAttr (
12261261 rewriter->getIntegerType (32 ), align));
@@ -1235,8 +1270,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
12351270 // / Convert a complex argument value. This can involve storing the value to
12361271 // / a temporary memory location or factoring the value into two distinct
12371272 // / arguments.
1238- template <typename FIXUPS>
1239- void doComplexArg (mlir::func::FuncOp func, mlir::ComplexType cmplx,
1273+ template <typename OpTy, typename FIXUPS>
1274+ void doComplexArg (OpTy func, mlir::ComplexType cmplx,
12401275 fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
12411276 FIXUPS &fixups) {
12421277 if (noComplexConversion) {
@@ -1248,8 +1283,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
12481283 createFuncOpArgFixups (func, newInTyAndAttrs, cplxArgs, fixups);
12491284 }
12501285
1251- template <typename FIXUPS>
1252- void doStructArg (mlir::func::FuncOp func, fir::RecordType recTy,
1286+ template <typename OpTy, typename FIXUPS>
1287+ void doStructArg (OpTy func, fir::RecordType recTy,
12531288 fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
12541289 FIXUPS &fixups) {
12551290 if (noStructConversion) {
0 commit comments