1313#include " flang/Optimizer/Analysis/AliasAnalysis.h"
1414#include " flang/Optimizer/Builder/FIRBuilder.h"
1515#include " flang/Optimizer/Builder/HLFIRTools.h"
16+ #include " flang/Optimizer/Dialect/FIRType.h"
1617#include " flang/Optimizer/HLFIR/HLFIROps.h"
1718#include " flang/Optimizer/HLFIR/Passes.h"
1819#include " flang/Optimizer/OpenMP/Passes.h"
@@ -127,6 +128,121 @@ class InlineHLFIRAssignConversion
127128 }
128129};
129130
131+ class InlineCopyInConversion : public mlir ::OpRewritePattern<hlfir::CopyInOp> {
132+ public:
133+ using mlir::OpRewritePattern<hlfir::CopyInOp>::OpRewritePattern;
134+
135+ llvm::LogicalResult
136+ matchAndRewrite (hlfir::CopyInOp copyIn,
137+ mlir::PatternRewriter &rewriter) const override ;
138+ };
139+
140+ llvm::LogicalResult
141+ InlineCopyInConversion::matchAndRewrite (hlfir::CopyInOp copyIn,
142+ mlir::PatternRewriter &rewriter) const {
143+ fir::FirOpBuilder builder (rewriter, copyIn.getOperation ());
144+ mlir::Location loc = copyIn.getLoc ();
145+ hlfir::Entity inputVariable{copyIn.getVar ()};
146+ if (!fir::isa_trivial (inputVariable.getFortranElementType ()))
147+ return rewriter.notifyMatchFailure (copyIn,
148+ " CopyInOp's data type is not trivial" );
149+
150+ if (fir::isPointerType (inputVariable.getType ()))
151+ return rewriter.notifyMatchFailure (
152+ copyIn, " CopyInOp's input variable is a pointer" );
153+
154+ // There should be exactly one user of WasCopied - the corresponding
155+ // CopyOutOp.
156+ if (copyIn.getWasCopied ().getUses ().empty ())
157+ return rewriter.notifyMatchFailure (copyIn,
158+ " CopyInOp's WasCopied has no uses" );
159+ // The copy out should always be present, either to actually copy or just
160+ // deallocate memory.
161+ auto *copyOut =
162+ copyIn.getWasCopied ().getUsers ().begin ().getCurrent ().getUser ();
163+
164+ if (!mlir::isa<hlfir::CopyOutOp>(copyOut))
165+ return rewriter.notifyMatchFailure (copyIn,
166+ " CopyInOp has no direct CopyOut" );
167+
168+ // Only inline the copy_in when copy_out does not need to be done, i.e. in
169+ // case of intent(in).
170+ if (::llvm::cast<hlfir::CopyOutOp>(copyOut).getVar ())
171+ return rewriter.notifyMatchFailure (copyIn, " CopyIn needs a copy-out" );
172+
173+ inputVariable =
174+ hlfir::derefPointersAndAllocatables (loc, builder, inputVariable);
175+ mlir::Type resultAddrType = copyIn.getCopiedIn ().getType ();
176+ mlir::Value isContiguous =
177+ builder.create <fir::IsContiguousBoxOp>(loc, inputVariable);
178+ auto results =
179+ builder
180+ .genIfOp (loc, {resultAddrType, builder.getI1Type ()}, isContiguous,
181+ /* withElseRegion=*/ true )
182+ .genThen ([&]() {
183+ mlir::Value falseVal = builder.create <mlir::arith::ConstantOp>(
184+ loc, builder.getI1Type (), builder.getBoolAttr (false ));
185+ builder.create <fir::ResultOp>(
186+ loc, mlir::ValueRange{inputVariable, falseVal});
187+ })
188+ .genElse ([&] {
189+ auto [temp, cleanup] =
190+ hlfir::createTempFromMold (loc, builder, inputVariable);
191+ mlir::Value shape = hlfir::genShape (loc, builder, inputVariable);
192+ llvm::SmallVector<mlir::Value> extents =
193+ hlfir::getIndexExtents (loc, builder, shape);
194+ hlfir::LoopNest loopNest = hlfir::genLoopNest (
195+ loc, builder, extents, /* isUnordered=*/ true ,
196+ flangomp::shouldUseWorkshareLowering (copyIn));
197+ builder.setInsertionPointToStart (loopNest.body );
198+ auto elem = hlfir::getElementAt (loc, builder, inputVariable,
199+ loopNest.oneBasedIndices );
200+ elem = hlfir::loadTrivialScalar (loc, builder, elem);
201+ auto tempElem = hlfir::getElementAt (loc, builder, temp,
202+ loopNest.oneBasedIndices );
203+ builder.create <hlfir::AssignOp>(loc, elem, tempElem);
204+ builder.setInsertionPointAfter (loopNest.outerOp );
205+
206+ mlir::Value result;
207+ // Make sure the result is always a boxed array by boxing it
208+ // ourselves if need be.
209+ if (mlir::isa<fir::BaseBoxType>(temp.getType ())) {
210+ result = temp;
211+ } else {
212+ auto refTy =
213+ fir::ReferenceType::get (temp.getElementOrSequenceType ());
214+ auto refVal = builder.createConvert (loc, refTy, temp);
215+ result =
216+ builder.create <fir::EmboxOp>(loc, resultAddrType, refVal);
217+ }
218+
219+ builder.create <fir::ResultOp>(loc,
220+ mlir::ValueRange{result, cleanup});
221+ })
222+ .getResults ();
223+
224+ auto addr = results[0 ];
225+ auto needsCleanup = results[1 ];
226+
227+ builder.setInsertionPoint (copyOut);
228+ builder.genIfOp (loc, {}, needsCleanup, false ).genThen ([&] {
229+ auto boxAddr = builder.create <fir::BoxAddrOp>(loc, addr);
230+ auto heapType = fir::HeapType::get (fir::BoxValue (addr).getBaseTy ());
231+ auto heapVal = builder.createConvert (loc, heapType, boxAddr.getResult ());
232+ builder.create <fir::FreeMemOp>(loc, heapVal);
233+ });
234+ rewriter.eraseOp (copyOut);
235+
236+ auto tempBox = copyIn.getTempBox ();
237+
238+ rewriter.replaceOp (copyIn, {addr, builder.genNot (loc, isContiguous)});
239+
240+ // The TempBox is only needed for flang-rt calls which we're no longer
241+ // generating.
242+ rewriter.eraseOp (tempBox.getDefiningOp ());
243+ return mlir::success ();
244+ }
245+
130246class InlineHLFIRAssignPass
131247 : public hlfir::impl::InlineHLFIRAssignBase<InlineHLFIRAssignPass> {
132248public:
@@ -140,6 +256,7 @@ class InlineHLFIRAssignPass
140256
141257 mlir::RewritePatternSet patterns (context);
142258 patterns.insert <InlineHLFIRAssignConversion>(context);
259+ patterns.insert <InlineCopyInConversion>(context);
143260
144261 if (mlir::failed (mlir::applyPatternsGreedily (
145262 getOperation (), std::move (patterns), config))) {
0 commit comments