@@ -142,20 +142,16 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
142142
143143 mlir::ModuleOp getModule () { return getOperation (); }
144144
145- template <typename A , typename B, typename C >
145+ template <typename Ty , typename Callback >
146146 std::optional<std::function<mlir::Value(mlir::Operation *)>>
147- rewriteCallComplexResultType (
148- mlir::Location loc, A ty, B &newResTys,
149- fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs, C &newOpers,
150- mlir::Value &savedStackPtr) {
151- if (noComplexConversion) {
152- newResTys.push_back (ty);
153- return std::nullopt ;
154- }
155- auto m = specifics->complexReturnType (loc, ty.getElementType ());
156- // Currently targets mandate COMPLEX is a single aggregate or packed
157- // scalar, including the sret case.
158- assert (m.size () == 1 && " target of complex return not supported" );
147+ rewriteCallResultType (mlir::Location loc, mlir::Type originalResTy,
148+ Ty &newResTys,
149+ fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
150+ Callback &newOpers, mlir::Value &savedStackPtr,
151+ fir::CodeGenSpecifics::Marshalling &m) {
152+ // Currently, targets mandate COMPLEX or STRUCT is a single aggregate or
153+ // packed scalar, including the sret case.
154+ assert (m.size () == 1 && " return type not supported on this target" );
159155 auto resTy = std::get<mlir::Type>(m[0 ]);
160156 auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0 ]);
161157 if (attr.isSRet ()) {
@@ -170,7 +166,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
170166 newInTyAndAttrs.push_back (m[0 ]);
171167 newOpers.push_back (stack);
172168 return [=](mlir::Operation *) -> mlir::Value {
173- auto memTy = fir::ReferenceType::get (ty );
169+ auto memTy = fir::ReferenceType::get (originalResTy );
174170 auto cast = rewriter->create <fir::ConvertOp>(loc, memTy, stack);
175171 return rewriter->create <fir::LoadOp>(loc, cast);
176172 };
@@ -180,11 +176,41 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
180176 // We are going to generate an alloca, so save the stack pointer.
181177 if (!savedStackPtr)
182178 savedStackPtr = genStackSave (loc);
183- return this ->convertValueInMemory (loc, call->getResult (0 ), ty ,
179+ return this ->convertValueInMemory (loc, call->getResult (0 ), originalResTy ,
184180 /* inputMayBeBigger=*/ true );
185181 };
186182 }
187183
184+ template <typename Ty, typename Callback>
185+ std::optional<std::function<mlir::Value(mlir::Operation *)>>
186+ rewriteCallComplexResultType (
187+ mlir::Location loc, mlir::ComplexType ty, Ty &newResTys,
188+ fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs, Callback &newOpers,
189+ mlir::Value &savedStackPtr) {
190+ if (noComplexConversion) {
191+ newResTys.push_back (ty);
192+ return std::nullopt ;
193+ }
194+ auto m = specifics->complexReturnType (loc, ty.getElementType ());
195+ return rewriteCallResultType (loc, ty, newResTys, newInTyAndAttrs, newOpers,
196+ savedStackPtr, m);
197+ }
198+
199+ template <typename Ty, typename Callback>
200+ std::optional<std::function<mlir::Value(mlir::Operation *)>>
201+ rewriteCallStructResultType (
202+ mlir::Location loc, fir::RecordType recTy, Ty &newResTys,
203+ fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs, Callback &newOpers,
204+ mlir::Value &savedStackPtr) {
205+ if (noStructConversion) {
206+ newResTys.push_back (recTy);
207+ return std::nullopt ;
208+ }
209+ auto m = specifics->structReturnType (loc, recTy);
210+ return rewriteCallResultType (loc, recTy, newResTys, newInTyAndAttrs,
211+ newOpers, savedStackPtr, m);
212+ }
213+
188214 void passArgumentOnStackOrWithNewType (
189215 mlir::Location loc, fir::CodeGenSpecifics::TypeAndAttr newTypeAndAttr,
190216 mlir::Type oldType, mlir::Value oper,
@@ -356,6 +382,11 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
356382 newInTyAndAttrs, newOpers,
357383 savedStackPtr);
358384 })
385+ .template Case <fir::RecordType>([&](fir::RecordType recTy) {
386+ wrap = rewriteCallStructResultType (loc, recTy, newResTys,
387+ newInTyAndAttrs, newOpers,
388+ savedStackPtr);
389+ })
359390 .Default ([&](mlir::Type ty) { newResTys.push_back (ty); });
360391 } else if (fnTy.getResults ().size () > 1 ) {
361392 TODO (loc, " multiple results not supported yet" );
@@ -562,6 +593,24 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
562593 }
563594 }
564595
596+ template <typename Ty>
597+ void
598+ lowerStructSignatureRes (mlir::Location loc, fir::RecordType recTy,
599+ Ty &newResTys,
600+ fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs) {
601+ if (noComplexConversion) {
602+ newResTys.push_back (recTy);
603+ return ;
604+ } else {
605+ for (auto &tup : specifics->structReturnType (loc, recTy)) {
606+ if (std::get<fir::CodeGenSpecifics::Attributes>(tup).isSRet ())
607+ newInTyAndAttrs.push_back (tup);
608+ else
609+ newResTys.push_back (std::get<mlir::Type>(tup));
610+ }
611+ }
612+ }
613+
565614 void
566615 lowerStructSignatureArg (mlir::Location loc, fir::RecordType recTy,
567616 fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs) {
@@ -595,6 +644,9 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
595644 .Case <mlir::ComplexType>([&](mlir::ComplexType ty) {
596645 lowerComplexSignatureRes (loc, ty, newResTys, newInTyAndAttrs);
597646 })
647+ .Case <fir::RecordType>([&](fir::RecordType ty) {
648+ lowerStructSignatureRes (loc, ty, newResTys, newInTyAndAttrs);
649+ })
598650 .Default ([&](mlir::Type ty) { newResTys.push_back (ty); });
599651 }
600652 llvm::SmallVector<mlir::Type> trailingInTys;
@@ -696,7 +748,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
696748 for (auto ty : func.getResults ())
697749 if ((mlir::isa<fir::BoxCharType>(ty) && !noCharacterConversion) ||
698750 (fir::isa_complex (ty) && !noComplexConversion) ||
699- (mlir::isa<mlir::IntegerType>(ty) && hasCCallingConv)) {
751+ (mlir::isa<mlir::IntegerType>(ty) && hasCCallingConv) ||
752+ (mlir::isa<fir::RecordType>(ty) && !noStructConversion)) {
700753 LLVM_DEBUG (llvm::dbgs () << " rewrite " << signature << " for target\n " );
701754 return false ;
702755 }
@@ -770,6 +823,9 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
770823 rewriter->getUnitAttr ()));
771824 newResTys.push_back (retTy);
772825 })
826+ .Case <fir::RecordType>([&](fir::RecordType recTy) {
827+ doStructReturn (func, recTy, newResTys, newInTyAndAttrs, fixups);
828+ })
773829 .Default ([&](mlir::Type ty) { newResTys.push_back (ty); });
774830
775831 // Saved potential shift in argument. Handling of result can add arguments
@@ -1062,21 +1118,12 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
10621118 return false ;
10631119 }
10641120
1065- // / Convert a complex return value. This can involve converting the return
1066- // / value to a "hidden" first argument or packing the complex into a wide
1067- // / GPR.
10681121 template <typename Ty, typename FIXUPS>
1069- void doComplexReturn (mlir::func::FuncOp func, mlir::ComplexType cmplx,
1070- Ty &newResTys,
1071- fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
1072- FIXUPS &fixups) {
1073- if (noComplexConversion) {
1074- newResTys.push_back (cmplx);
1075- return ;
1076- }
1077- auto m =
1078- specifics->complexReturnType (func.getLoc (), cmplx.getElementType ());
1079- assert (m.size () == 1 );
1122+ void doReturn (mlir::func::FuncOp func, Ty &newResTys,
1123+ fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
1124+ FIXUPS &fixups, fir::CodeGenSpecifics::Marshalling &m) {
1125+ assert (m.size () == 1 &&
1126+ " expect result to be turned into single argument or result so far" );
10801127 auto &tup = m[0 ];
10811128 auto attr = std::get<fir::CodeGenSpecifics::Attributes>(tup);
10821129 auto argTy = std::get<mlir::Type>(tup);
@@ -1117,6 +1164,36 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
11171164 newResTys.push_back (argTy);
11181165 }
11191166
1167+ // / Convert a complex return value. This can involve converting the return
1168+ // / value to a "hidden" first argument or packing the complex into a wide
1169+ // / GPR.
1170+ template <typename Ty, typename FIXUPS>
1171+ void doComplexReturn (mlir::func::FuncOp func, mlir::ComplexType cmplx,
1172+ Ty &newResTys,
1173+ fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
1174+ FIXUPS &fixups) {
1175+ if (noComplexConversion) {
1176+ newResTys.push_back (cmplx);
1177+ return ;
1178+ }
1179+ auto m =
1180+ specifics->complexReturnType (func.getLoc (), cmplx.getElementType ());
1181+ doReturn (func, newResTys, newInTyAndAttrs, fixups, m);
1182+ }
1183+
1184+ template <typename Ty, typename FIXUPS>
1185+ void doStructReturn (mlir::func::FuncOp func, fir::RecordType recTy,
1186+ Ty &newResTys,
1187+ fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
1188+ FIXUPS &fixups) {
1189+ if (noStructConversion) {
1190+ newResTys.push_back (recTy);
1191+ return ;
1192+ }
1193+ auto m = specifics->structReturnType (func.getLoc (), recTy);
1194+ doReturn (func, newResTys, newInTyAndAttrs, fixups, m);
1195+ }
1196+
11201197 template <typename FIXUPS>
11211198 void
11221199 createFuncOpArgFixups (mlir::func::FuncOp func,
0 commit comments