Skip to content

Commit 7820693

Browse files
committed
[flang] Inline hlfir.copy_in for trivial types
hlfir.copy_in implements copying non-contiguous array slices for functions that take in arrays required to be contiguous through a flang-rt function that calls memcpy/memmove separately on each element. For large arrays of trivial types, this can incur considerable overhead compared to a plain copy loop that is better able to take advantage of hardware pipelines. To address that, extend the InlineHLFIRAssign optimisation pass with a new pattern for inlining hlfir.copy_in operations for trivial types. For the time being, the pattern is only applied in cases where the copy-in does not require a corresponding copy-out, such as when the function being called declares the array parameter as intent(in). Applying this optimisation reduces the runtime of thornado-mini's DeleptonizationProblem by a factor of about 1/3rd. Signed-off-by: Kajetan Puchalski <[email protected]>
1 parent 6421248 commit 7820693

File tree

1 file changed

+117
-0
lines changed

1 file changed

+117
-0
lines changed

flang/lib/Optimizer/HLFIR/Transforms/InlineHLFIRAssign.cpp

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "flang/Optimizer/Analysis/AliasAnalysis.h"
1414
#include "flang/Optimizer/Builder/FIRBuilder.h"
1515
#include "flang/Optimizer/Builder/HLFIRTools.h"
16+
#include "flang/Optimizer/Dialect/FIRType.h"
1617
#include "flang/Optimizer/HLFIR/HLFIROps.h"
1718
#include "flang/Optimizer/HLFIR/Passes.h"
1819
#include "flang/Optimizer/OpenMP/Passes.h"
@@ -127,6 +128,121 @@ class InlineHLFIRAssignConversion
127128
}
128129
};
129130

131+
class InlineCopyInConversion : public mlir::OpRewritePattern<hlfir::CopyInOp> {
132+
public:
133+
using mlir::OpRewritePattern<hlfir::CopyInOp>::OpRewritePattern;
134+
135+
llvm::LogicalResult
136+
matchAndRewrite(hlfir::CopyInOp copyIn,
137+
mlir::PatternRewriter &rewriter) const override;
138+
};
139+
140+
llvm::LogicalResult
141+
InlineCopyInConversion::matchAndRewrite(hlfir::CopyInOp copyIn,
142+
mlir::PatternRewriter &rewriter) const {
143+
fir::FirOpBuilder builder(rewriter, copyIn.getOperation());
144+
mlir::Location loc = copyIn.getLoc();
145+
hlfir::Entity inputVariable{copyIn.getVar()};
146+
if (!fir::isa_trivial(inputVariable.getFortranElementType()))
147+
return rewriter.notifyMatchFailure(copyIn,
148+
"CopyInOp's data type is not trivial");
149+
150+
if (fir::isPointerType(inputVariable.getType()))
151+
return rewriter.notifyMatchFailure(
152+
copyIn, "CopyInOp's input variable is a pointer");
153+
154+
// There should be exactly one user of WasCopied - the corresponding
155+
// CopyOutOp.
156+
if (copyIn.getWasCopied().getUses().empty())
157+
return rewriter.notifyMatchFailure(copyIn,
158+
"CopyInOp's WasCopied has no uses");
159+
// The copy out should always be present, either to actually copy or just
160+
// deallocate memory.
161+
auto *copyOut =
162+
copyIn.getWasCopied().getUsers().begin().getCurrent().getUser();
163+
164+
if (!mlir::isa<hlfir::CopyOutOp>(copyOut))
165+
return rewriter.notifyMatchFailure(copyIn,
166+
"CopyInOp has no direct CopyOut");
167+
168+
// Only inline the copy_in when copy_out does not need to be done, i.e. in
169+
// case of intent(in).
170+
if (::llvm::cast<hlfir::CopyOutOp>(copyOut).getVar())
171+
return rewriter.notifyMatchFailure(copyIn, "CopyIn needs a copy-out");
172+
173+
inputVariable =
174+
hlfir::derefPointersAndAllocatables(loc, builder, inputVariable);
175+
mlir::Type resultAddrType = copyIn.getCopiedIn().getType();
176+
mlir::Value isContiguous =
177+
builder.create<fir::IsContiguousBoxOp>(loc, inputVariable);
178+
auto results =
179+
builder
180+
.genIfOp(loc, {resultAddrType, builder.getI1Type()}, isContiguous,
181+
/*withElseRegion=*/true)
182+
.genThen([&]() {
183+
mlir::Value falseVal = builder.create<mlir::arith::ConstantOp>(
184+
loc, builder.getI1Type(), builder.getBoolAttr(false));
185+
builder.create<fir::ResultOp>(
186+
loc, mlir::ValueRange{inputVariable, falseVal});
187+
})
188+
.genElse([&] {
189+
auto [temp, cleanup] =
190+
hlfir::createTempFromMold(loc, builder, inputVariable);
191+
mlir::Value shape = hlfir::genShape(loc, builder, inputVariable);
192+
llvm::SmallVector<mlir::Value> extents =
193+
hlfir::getIndexExtents(loc, builder, shape);
194+
hlfir::LoopNest loopNest = hlfir::genLoopNest(
195+
loc, builder, extents, /*isUnordered=*/true,
196+
flangomp::shouldUseWorkshareLowering(copyIn));
197+
builder.setInsertionPointToStart(loopNest.body);
198+
auto elem = hlfir::getElementAt(loc, builder, inputVariable,
199+
loopNest.oneBasedIndices);
200+
elem = hlfir::loadTrivialScalar(loc, builder, elem);
201+
auto tempElem = hlfir::getElementAt(loc, builder, temp,
202+
loopNest.oneBasedIndices);
203+
builder.create<hlfir::AssignOp>(loc, elem, tempElem);
204+
builder.setInsertionPointAfter(loopNest.outerOp);
205+
206+
mlir::Value result;
207+
// Make sure the result is always a boxed array by boxing it
208+
// ourselves if need be.
209+
if (mlir::isa<fir::BaseBoxType>(temp.getType())) {
210+
result = temp;
211+
} else {
212+
auto refTy =
213+
fir::ReferenceType::get(temp.getElementOrSequenceType());
214+
auto refVal = builder.createConvert(loc, refTy, temp);
215+
result =
216+
builder.create<fir::EmboxOp>(loc, resultAddrType, refVal);
217+
}
218+
219+
builder.create<fir::ResultOp>(loc,
220+
mlir::ValueRange{result, cleanup});
221+
})
222+
.getResults();
223+
224+
auto addr = results[0];
225+
auto needsCleanup = results[1];
226+
227+
builder.setInsertionPoint(copyOut);
228+
builder.genIfOp(loc, {}, needsCleanup, false).genThen([&] {
229+
auto boxAddr = builder.create<fir::BoxAddrOp>(loc, addr);
230+
auto heapType = fir::HeapType::get(fir::BoxValue(addr).getBaseTy());
231+
auto heapVal = builder.createConvert(loc, heapType, boxAddr.getResult());
232+
builder.create<fir::FreeMemOp>(loc, heapVal);
233+
});
234+
rewriter.eraseOp(copyOut);
235+
236+
auto tempBox = copyIn.getTempBox();
237+
238+
rewriter.replaceOp(copyIn, {addr, builder.genNot(loc, isContiguous)});
239+
240+
// The TempBox is only needed for flang-rt calls which we're no longer
241+
// generating.
242+
rewriter.eraseOp(tempBox.getDefiningOp());
243+
return mlir::success();
244+
}
245+
130246
class InlineHLFIRAssignPass
131247
: public hlfir::impl::InlineHLFIRAssignBase<InlineHLFIRAssignPass> {
132248
public:
@@ -140,6 +256,7 @@ class InlineHLFIRAssignPass
140256

141257
mlir::RewritePatternSet patterns(context);
142258
patterns.insert<InlineHLFIRAssignConversion>(context);
259+
patterns.insert<InlineCopyInConversion>(context);
143260

144261
if (mlir::failed(mlir::applyPatternsGreedily(
145262
getOperation(), std::move(patterns), config))) {

0 commit comments

Comments
 (0)