@@ -180,11 +180,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
180180 // We are going to generate an alloca, so save the stack pointer.
181181 if (!savedStackPtr)
182182 savedStackPtr = genStackSave (loc);
183- auto mem = rewriter->create <fir::AllocaOp>(loc, resTy);
184- rewriter->create <fir::StoreOp>(loc, call->getResult (0 ), mem);
185- auto memTy = fir::ReferenceType::get (ty);
186- auto cast = rewriter->create <fir::ConvertOp>(loc, memTy, mem);
187- return rewriter->create <fir::LoadOp>(loc, cast);
183+ return this ->convertValueInMemory (loc, call->getResult (0 ), ty,
184+ /* inputMayBeBigger=*/ true );
188185 };
189186 }
190187
@@ -195,7 +192,6 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
195192 mlir::Value &savedStackPtr) {
196193 auto resTy = std::get<mlir::Type>(newTypeAndAttr);
197194 auto attr = std::get<fir::CodeGenSpecifics::Attributes>(newTypeAndAttr);
198- auto oldRefTy = fir::ReferenceType::get (oldType);
199195 // We are going to generate an alloca, so save the stack pointer.
200196 if (!savedStackPtr)
201197 savedStackPtr = genStackSave (loc);
@@ -206,11 +202,83 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
206202 mem = rewriter->create <fir::ConvertOp>(loc, resTy, mem);
207203 newOpers.push_back (mem);
208204 } else {
209- auto mem = rewriter->create <fir::AllocaOp>(loc, resTy);
205+ mlir::Value bitcast =
206+ convertValueInMemory (loc, oper, resTy, /* inputMayBeBigger=*/ false );
207+ newOpers.push_back (bitcast);
208+ }
209+ }
210+
211+ // Do a bitcast (convert a value via its memory representation).
212+ // The input and output types may have different storage sizes,
213+ // "inputMayBeBigger" should be set to indicate which of the input or
214+ // output type may be bigger in order for the load/store to be safe.
215+ // The mismatch comes from the fact that the LLVM register used for passing
216+ // may be bigger than the value being passed (e.g., passing
217+ // a `!fir.type<t{fir.array<3xi8>}>` into an i32 LLVM register).
218+ mlir::Value convertValueInMemory (mlir::Location loc, mlir::Value value,
219+ mlir::Type newType, bool inputMayBeBigger) {
220+ if (inputMayBeBigger) {
221+ auto newRefTy = fir::ReferenceType::get (newType);
222+ auto mem = rewriter->create <fir::AllocaOp>(loc, value.getType ());
223+ rewriter->create <fir::StoreOp>(loc, value, mem);
224+ auto cast = rewriter->create <fir::ConvertOp>(loc, newRefTy, mem);
225+ return rewriter->create <fir::LoadOp>(loc, cast);
226+ } else {
227+ auto oldRefTy = fir::ReferenceType::get (value.getType ());
228+ auto mem = rewriter->create <fir::AllocaOp>(loc, newType);
210229 auto cast = rewriter->create <fir::ConvertOp>(loc, oldRefTy, mem);
211- rewriter->create <fir::StoreOp>(loc, oper, cast);
212- newOpers.push_back (rewriter->create <fir::LoadOp>(loc, mem));
230+ rewriter->create <fir::StoreOp>(loc, value, cast);
231+ return rewriter->create <fir::LoadOp>(loc, mem);
232+ }
233+ }
234+
235+ void passSplitArgument (mlir::Location loc,
236+ fir::CodeGenSpecifics::Marshalling splitArgs,
237+ mlir::Type oldType, mlir::Value oper,
238+ llvm::SmallVectorImpl<mlir::Value> &newOpers,
239+ mlir::Value &savedStackPtr) {
240+ // COMPLEX or struct argument split into separate arguments
241+ if (!fir::isa_complex (oldType)) {
242+ // Cast original operand to a tuple of the new arguments
243+ // via memory.
244+ llvm::SmallVector<mlir::Type> partTypes;
245+ for (auto argPart : splitArgs)
246+ partTypes.push_back (std::get<mlir::Type>(argPart));
247+ mlir::Type tupleType =
248+ mlir::TupleType::get (oldType.getContext (), partTypes);
249+ if (!savedStackPtr)
250+ savedStackPtr = genStackSave (loc);
251+ oper = convertValueInMemory (loc, oper, tupleType,
252+ /* inputMayBeBigger=*/ false );
253+ }
254+ auto iTy = rewriter->getIntegerType (32 );
255+ for (auto e : llvm::enumerate (splitArgs)) {
256+ auto &tup = e.value ();
257+ auto ty = std::get<mlir::Type>(tup);
258+ auto index = e.index ();
259+ auto idx = rewriter->getIntegerAttr (iTy, index);
260+ auto val = rewriter->create <fir::ExtractValueOp>(
261+ loc, ty, oper, rewriter->getArrayAttr (idx));
262+ newOpers.push_back (val);
263+ }
264+ }
265+
266+ void rewriteCallOperands (
267+ mlir::Location loc, fir::CodeGenSpecifics::Marshalling passArgAs,
268+ mlir::Type originalArgTy, mlir::Value oper,
269+ llvm::SmallVectorImpl<mlir::Value> &newOpers, mlir::Value &savedStackPtr,
270+ fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs) {
271+ if (passArgAs.size () == 1 ) {
272+ // COMPLEX or derived type is passed as a single argument.
273+ passArgumentOnStackOrWithNewType (loc, passArgAs[0 ], originalArgTy, oper,
274+ newOpers, savedStackPtr);
275+ } else {
276+ // COMPLEX or derived type is split into separate arguments
277+ passSplitArgument (loc, passArgAs, originalArgTy, oper, newOpers,
278+ savedStackPtr);
213279 }
280+ newInTyAndAttrs.insert (newInTyAndAttrs.end (), passArgAs.begin (),
281+ passArgAs.end ());
214282 }
215283
216284 template <typename CPLX>
@@ -224,28 +292,9 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
224292 newOpers.push_back (oper);
225293 return ;
226294 }
227-
228295 auto m = specifics->complexArgumentType (loc, ty.getElementType ());
229- if (m.size () == 1 ) {
230- // COMPLEX is a single aggregate
231- passArgumentOnStackOrWithNewType (loc, m[0 ], ty, oper, newOpers,
232- savedStackPtr);
233- newInTyAndAttrs.push_back (m[0 ]);
234- } else {
235- assert (m.size () == 2 );
236- // COMPLEX is split into 2 separate arguments
237- auto iTy = rewriter->getIntegerType (32 );
238- for (auto e : llvm::enumerate (m)) {
239- auto &tup = e.value ();
240- auto ty = std::get<mlir::Type>(tup);
241- auto index = e.index ();
242- auto idx = rewriter->getIntegerAttr (iTy, index);
243- auto val = rewriter->create <fir::ExtractValueOp>(
244- loc, ty, oper, rewriter->getArrayAttr (idx));
245- newInTyAndAttrs.push_back (tup);
246- newOpers.push_back (val);
247- }
248- }
296+ rewriteCallOperands (loc, m, ty, oper, newOpers, savedStackPtr,
297+ newInTyAndAttrs);
249298 }
250299
251300 void rewriteCallStructInputType (
@@ -260,11 +309,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
260309 }
261310 auto structArgs =
262311 specifics->structArgumentType (loc, recTy, newInTyAndAttrs);
263- if (structArgs.size () != 1 )
264- TODO (loc, " splitting BIND(C), VALUE derived type into several arguments" );
265- passArgumentOnStackOrWithNewType (loc, structArgs[0 ], recTy, oper, newOpers,
266- savedStackPtr);
267- structArgs.push_back (structArgs[0 ]);
312+ rewriteCallOperands (loc, structArgs, recTy, oper, newOpers, savedStackPtr,
313+ newInTyAndAttrs);
268314 }
269315
270316 static bool hasByValOrSRetArgs (
@@ -849,24 +895,21 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
849895 case FixupTy::Codes::ArgumentType: {
850896 // Argument is pass-by-value, but its type has likely been modified to
851897 // suit the target ABI convention.
852- auto oldArgTy =
853- fir::ReferenceType::get (oldArgTys[fixup.index - offset]);
898+ auto oldArgTy = oldArgTys[fixup.index - offset];
854899 // If type did not change, keep the original argument.
855900 if (fixupType == oldArgTy)
856901 break ;
857902
858903 auto newArg =
859904 func.front ().insertArgument (fixup.index , fixupType, loc);
860905 rewriter->setInsertionPointToStart (&func.front ());
861- auto mem = rewriter->create <fir::AllocaOp>(loc, fixupType);
862- rewriter->create <fir::StoreOp>(loc, newArg, mem);
863- auto cast = rewriter->create <fir::ConvertOp>(loc, oldArgTy, mem);
864- mlir::Value load = rewriter->create <fir::LoadOp>(loc, cast);
865- func.getArgument (fixup.index + 1 ).replaceAllUsesWith (load);
906+ mlir::Value bitcast = convertValueInMemory (loc, newArg, oldArgTy,
907+ /* inputMayBeBigger=*/ true );
908+ func.getArgument (fixup.index + 1 ).replaceAllUsesWith (bitcast);
866909 func.front ().eraseArgument (fixup.index + 1 );
867910 LLVM_DEBUG (llvm::dbgs ()
868- << " old argument: " << oldArgTy. getEleTy ()
869- << " , repl: " << load << " , new argument: "
911+ << " old argument: " << oldArgTy << " , repl: " << bitcast
912+ << " , new argument: "
870913 << func.getArgument (fixup.index ).getType () << ' \n ' );
871914 } break ;
872915 case FixupTy::Codes::CharPair: {
@@ -907,34 +950,43 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
907950 func.walk ([&](mlir::func::ReturnOp ret) {
908951 rewriter->setInsertionPoint (ret);
909952 auto oldOper = ret.getOperand (0 );
910- auto oldOperTy = fir::ReferenceType::get (oldOper.getType ());
911- auto mem =
912- rewriter->create <fir::AllocaOp>(loc, newResTys[fixup.index ]);
913- auto cast = rewriter->create <fir::ConvertOp>(loc, oldOperTy, mem);
914- rewriter->create <fir::StoreOp>(loc, oldOper, cast);
915- mlir::Value load = rewriter->create <fir::LoadOp>(loc, mem);
916- rewriter->create <mlir::func::ReturnOp>(loc, load);
953+ mlir::Value bitcast =
954+ convertValueInMemory (loc, oldOper, newResTys[fixup.index ],
955+ /* inputMayBeBigger=*/ false );
956+ rewriter->create <mlir::func::ReturnOp>(loc, bitcast);
917957 ret.erase ();
918958 });
919959 } break ;
920960 case FixupTy::Codes::Split: {
921961 // The FIR argument has been split into a pair of distinct arguments
922- // that are in juxtaposition to each other. (For COMPLEX value.)
962+ // that are in juxtaposition to each other. (For COMPLEX value or
963+ // derived type passed with VALUE in BIND(C) context).
923964 auto newArg =
924965 func.front ().insertArgument (fixup.index , fixupType, loc);
925966 if (fixup.second == 1 ) {
926967 rewriter->setInsertionPointToStart (&func.front ());
927- auto cplxTy = oldArgTys[fixup.index - offset - fixup.second ];
928- auto undef = rewriter->create <fir::UndefOp>(loc, cplxTy);
968+ mlir::Value firstArg = func.front ().getArgument (fixup.index - 1 );
969+ mlir::Type originalTy =
970+ oldArgTys[fixup.index - offset - fixup.second ];
971+ mlir::Type pairTy = originalTy;
972+ if (!fir::isa_complex (originalTy)) {
973+ pairTy = mlir::TupleType::get (
974+ originalTy.getContext (),
975+ mlir::TypeRange{firstArg.getType (), newArg.getType ()});
976+ }
977+ auto undef = rewriter->create <fir::UndefOp>(loc, pairTy);
929978 auto iTy = rewriter->getIntegerType (32 );
930979 auto zero = rewriter->getIntegerAttr (iTy, 0 );
931980 auto one = rewriter->getIntegerAttr (iTy, 1 );
932- auto cplx1 = rewriter->create <fir::InsertValueOp>(
933- loc, cplxTy, undef, func.front ().getArgument (fixup.index - 1 ),
934- rewriter->getArrayAttr (zero));
935- auto cplx = rewriter->create <fir::InsertValueOp>(
936- loc, cplxTy, cplx1, newArg, rewriter->getArrayAttr (one));
937- func.getArgument (fixup.index + 1 ).replaceAllUsesWith (cplx);
981+ mlir::Value pair1 = rewriter->create <fir::InsertValueOp>(
982+ loc, pairTy, undef, firstArg, rewriter->getArrayAttr (zero));
983+ mlir::Value pair = rewriter->create <fir::InsertValueOp>(
984+ loc, pairTy, pair1, newArg, rewriter->getArrayAttr (one));
985+ // Cast local argument tuple to original type via memory if needed.
986+ if (pairTy != originalTy)
987+ pair = convertValueInMemory (loc, pair, originalTy,
988+ /* inputMayBeBigger=*/ true );
989+ func.getArgument (fixup.index + 1 ).replaceAllUsesWith (pair);
938990 func.front ().eraseArgument (fixup.index + 1 );
939991 offset++;
940992 }
0 commit comments