Skip to content

Commit 5d43a6b

Browse files
committed
[flang][hlfir] add hlfir.eval_in_mem operation
1 parent 5eeb3fe commit 5d43a6b

File tree

8 files changed

+454
-20
lines changed

8 files changed

+454
-20
lines changed

flang/include/flang/Optimizer/Builder/HLFIRTools.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ class AssociateOp;
3333
class ElementalOp;
3434
class ElementalOpInterface;
3535
class ElementalAddrOp;
36+
class EvaluateInMemoryOp;
3637
class YieldElementOp;
3738

3839
/// Is this a Fortran variable for which the defining op carrying the Fortran
@@ -398,6 +399,24 @@ mlir::Value inlineElementalOp(
398399
mlir::IRMapping &mapper,
399400
const std::function<bool(hlfir::ElementalOp)> &mustRecursivelyInline);
400401

402+
/// Create a new temporary with the shape and parameters of the provided
403+
/// hlfir.eval_in_mem operation and clone the body of the hlfir.eval_in_mem
404+
/// operating on this new temporary. returns the temporary and whether the
405+
/// temporary is heap or stack allocated.
406+
std::pair<hlfir::Entity, bool>
407+
computeEvaluateOpInNewTemp(mlir::Location, fir::FirOpBuilder &,
408+
hlfir::EvaluateInMemoryOp evalInMem,
409+
mlir::Value shape, mlir::ValueRange typeParams);
410+
411+
// Clone the body of the hlfir.eval_in_mem operating on this the provided
412+
// storage. The provided storage must be a contiguous "raw" memory reference
413+
// (not a fir.box) big enough to hold the value computed by hlfir.eval_in_mem.
414+
// No runtime check is inserted by this utility to enforce that. It is also
415+
// usually invalid to provide some storage that is already addressed directly
416+
// or indirectly inside the hlfir.eval_in_mem body.
417+
void computeEvaluateOpIn(mlir::Location, fir::FirOpBuilder &,
418+
hlfir::EvaluateInMemoryOp, mlir::Value storage);
419+
401420
std::pair<fir::ExtendedValue, std::optional<hlfir::CleanupFunction>>
402421
convertToValue(mlir::Location loc, fir::FirOpBuilder &builder,
403422
hlfir::Entity entity);

flang/include/flang/Optimizer/HLFIR/HLFIROps.td

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1755,4 +1755,63 @@ def hlfir_CharExtremumOp : hlfir_Op<"char_extremum",
17551755
let hasVerifier = 1;
17561756
}
17571757

1758+
def hlfir_EvaluateInMemoryOp : hlfir_Op<"eval_in_mem", [AttrSizedOperandSegments,
1759+
RecursiveMemoryEffects, RecursivelySpeculatable,
1760+
SingleBlockImplicitTerminator<"fir::FirEndOp">]> {
1761+
let summary = "Wrap an in-memory implementation that computes expression value";
1762+
let description = [{
1763+
Returns a Fortran expression value for which the computation is
1764+
implemented inside the region operating on the block argument which
1765+
is a raw memory reference corresponding to the expression type.
1766+
1767+
The shape and type parameters of the expressions are operands of the
1768+
operations.
1769+
1770+
The memory cannot escape the region, and it is not described how it is
1771+
allocated. This facilitates later elision of the temporary storage for the
1772+
expression evaluation if it can be evaluated in some other storage (like a
1773+
left-hand side variable).
1774+
1775+
Example:
1776+
1777+
A function returning an array can be represented as:
1778+
```
1779+
%1 = fir.shape %c10 : (index) -> !fir.shape<1>
1780+
%2 = hlfir.eval_in_mem shape %1 : (!fir.shape<1>) -> !hlfir.expr<10xf32> {
1781+
^bb0(%arg0: !fir.ref<!fir.array<10xf32>>):
1782+
%3 = fir.call @_QParray_func() fastmath<contract> : () -> !fir.array<10xf32>
1783+
fir.save_result %3 to %arg0(%1) : !fir.array<10xf32>, !fir.ref<!fir.array<10xf32>>, !fir.shape<1>
1784+
}
1785+
```
1786+
}];
1787+
1788+
let arguments = (ins
1789+
Optional<fir_ShapeType>:$shape,
1790+
Variadic<AnyIntegerType>:$typeparams
1791+
);
1792+
1793+
let results = (outs hlfir_ExprType);
1794+
let regions = (region SizedRegion<1>:$body);
1795+
1796+
let assemblyFormat = [{
1797+
(`shape` $shape^)? (`typeparams` $typeparams^)?
1798+
attr-dict `:` functional-type(operands, results)
1799+
$body}];
1800+
1801+
let skipDefaultBuilders = 1;
1802+
let builders = [
1803+
OpBuilder<(ins "mlir::Type":$result_type, "mlir::Value":$shape,
1804+
CArg<"mlir::ValueRange", "{}">:$typeparams)>
1805+
];
1806+
1807+
let extraClassDeclaration = [{
1808+
// Return block argument representing the memory where the expression
1809+
// is evaluated.
1810+
mlir::Value getMemory() {return getBody().getArgument(0);}
1811+
}];
1812+
1813+
let hasVerifier = 1;
1814+
}
1815+
1816+
17581817
#endif // FORTRAN_DIALECT_HLFIR_OPS

flang/lib/Optimizer/Builder/HLFIRTools.cpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,8 @@ static mlir::Value tryRetrievingShapeOrShift(hlfir::Entity entity) {
535535
if (mlir::isa<hlfir::ExprType>(entity.getType())) {
536536
if (auto elemental = entity.getDefiningOp<hlfir::ElementalOp>())
537537
return elemental.getShape();
538+
if (auto evalInMem = entity.getDefiningOp<hlfir::EvaluateInMemoryOp>())
539+
return evalInMem.getShape();
538540
return mlir::Value{};
539541
}
540542
if (auto varIface = entity.getIfVariableInterface())
@@ -642,6 +644,11 @@ void hlfir::genLengthParameters(mlir::Location loc, fir::FirOpBuilder &builder,
642644
result.append(elemental.getTypeparams().begin(),
643645
elemental.getTypeparams().end());
644646
return;
647+
} else if (auto evalInMem =
648+
expr.getDefiningOp<hlfir::EvaluateInMemoryOp>()) {
649+
result.append(evalInMem.getTypeparams().begin(),
650+
evalInMem.getTypeparams().end());
651+
return;
645652
} else if (auto apply = expr.getDefiningOp<hlfir::ApplyOp>()) {
646653
result.append(apply.getTypeparams().begin(), apply.getTypeparams().end());
647654
return;
@@ -1313,3 +1320,43 @@ hlfir::genTypeAndKindConvert(mlir::Location loc, fir::FirOpBuilder &builder,
13131320
};
13141321
return {hlfir::Entity{convertedRhs}, cleanup};
13151322
}
1323+
1324+
std::pair<hlfir::Entity, bool> hlfir::computeEvaluateOpInNewTemp(
1325+
mlir::Location loc, fir::FirOpBuilder &builder,
1326+
hlfir::EvaluateInMemoryOp evalInMem, mlir::Value shape,
1327+
mlir::ValueRange typeParams) {
1328+
llvm::StringRef tmpName{".tmp.expr_result"};
1329+
llvm::SmallVector<mlir::Value> extents =
1330+
hlfir::getIndexExtents(loc, builder, shape);
1331+
mlir::Type baseType =
1332+
hlfir::getFortranElementOrSequenceType(evalInMem.getType());
1333+
bool heapAllocated = fir::hasDynamicSize(baseType);
1334+
// Note: temporaries are stack allocated here when possible (do not require
1335+
// stack save/restore) because flang has always stack allocated function
1336+
// results.
1337+
mlir::Value temp = heapAllocated
1338+
? builder.createHeapTemporary(loc, baseType, tmpName,
1339+
extents, typeParams)
1340+
: builder.createTemporary(loc, baseType, tmpName,
1341+
extents, typeParams);
1342+
mlir::Value innerMemory = evalInMem.getMemory();
1343+
temp = builder.createConvert(loc, innerMemory.getType(), temp);
1344+
auto declareOp = builder.create<hlfir::DeclareOp>(
1345+
loc, temp, tmpName, shape, typeParams,
1346+
/*dummy_scope=*/nullptr, fir::FortranVariableFlagsAttr{});
1347+
computeEvaluateOpIn(loc, builder, evalInMem, declareOp.getOriginalBase());
1348+
return {hlfir::Entity{declareOp.getBase()}, /*heapAllocated=*/heapAllocated};
1349+
}
1350+
1351+
void hlfir::computeEvaluateOpIn(mlir::Location loc, fir::FirOpBuilder &builder,
1352+
hlfir::EvaluateInMemoryOp evalInMem,
1353+
mlir::Value storage) {
1354+
mlir::Value innerMemory = evalInMem.getMemory();
1355+
mlir::Value storageCast =
1356+
builder.createConvert(loc, innerMemory.getType(), storage);
1357+
mlir::IRMapping mapper;
1358+
mapper.map(innerMemory, storageCast);
1359+
for (auto &op : evalInMem.getBody().front().without_terminator())
1360+
builder.clone(op, mapper);
1361+
return;
1362+
}

flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp

Lines changed: 62 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,25 @@ static void printDesignatorComplexPart(mlir::OpAsmPrinter &p,
333333
p << "real";
334334
}
335335
}
336+
template <typename Op>
337+
static llvm::LogicalResult verifyTypeparams(Op &op, mlir::Type elementType,
338+
unsigned numLenParam) {
339+
if (mlir::isa<fir::CharacterType>(elementType)) {
340+
if (numLenParam != 1)
341+
return op.emitOpError("must be provided one length parameter when the "
342+
"result is a character");
343+
} else if (fir::isRecordWithTypeParameters(elementType)) {
344+
if (numLenParam !=
345+
mlir::cast<fir::RecordType>(elementType).getNumLenParams())
346+
return op.emitOpError("must be provided the same number of length "
347+
"parameters as in the result derived type");
348+
} else if (numLenParam != 0) {
349+
return op.emitOpError(
350+
"must not be provided length parameters if the result "
351+
"type does not have length parameters");
352+
}
353+
return mlir::success();
354+
}
336355

337356
llvm::LogicalResult hlfir::DesignateOp::verify() {
338357
mlir::Type memrefType = getMemref().getType();
@@ -462,20 +481,10 @@ llvm::LogicalResult hlfir::DesignateOp::verify() {
462481
return emitOpError("shape must be a fir.shape or fir.shapeshift with "
463482
"the rank of the result");
464483
}
465-
auto numLenParam = getTypeparams().size();
466-
if (mlir::isa<fir::CharacterType>(outputElementType)) {
467-
if (numLenParam != 1)
468-
return emitOpError("must be provided one length parameter when the "
469-
"result is a character");
470-
} else if (fir::isRecordWithTypeParameters(outputElementType)) {
471-
if (numLenParam !=
472-
mlir::cast<fir::RecordType>(outputElementType).getNumLenParams())
473-
return emitOpError("must be provided the same number of length "
474-
"parameters as in the result derived type");
475-
} else if (numLenParam != 0) {
476-
return emitOpError("must not be provided length parameters if the result "
477-
"type does not have length parameters");
478-
}
484+
if (auto res =
485+
verifyTypeparams(*this, outputElementType, getTypeparams().size());
486+
failed(res))
487+
return res;
479488
}
480489
return mlir::success();
481490
}
@@ -1989,6 +1998,45 @@ hlfir::GetLengthOp::canonicalize(GetLengthOp getLength,
19891998
return mlir::success();
19901999
}
19912000

2001+
//===----------------------------------------------------------------------===//
2002+
// EvaluateInMemoryOp
2003+
//===----------------------------------------------------------------------===//
2004+
2005+
void hlfir::EvaluateInMemoryOp::build(mlir::OpBuilder &builder,
2006+
mlir::OperationState &odsState,
2007+
mlir::Type resultType, mlir::Value shape,
2008+
mlir::ValueRange typeparams) {
2009+
odsState.addTypes(resultType);
2010+
if (shape)
2011+
odsState.addOperands(shape);
2012+
odsState.addOperands(typeparams);
2013+
odsState.addAttribute(
2014+
getOperandSegmentSizeAttr(),
2015+
builder.getDenseI32ArrayAttr(
2016+
{shape ? 1 : 0, static_cast<int32_t>(typeparams.size())}));
2017+
mlir::Region *bodyRegion = odsState.addRegion();
2018+
bodyRegion->push_back(new mlir::Block{});
2019+
mlir::Type memType = fir::ReferenceType::get(
2020+
hlfir::getFortranElementOrSequenceType(resultType));
2021+
bodyRegion->front().addArgument(memType, odsState.location);
2022+
EvaluateInMemoryOp::ensureTerminator(*bodyRegion, builder, odsState.location);
2023+
}
2024+
2025+
llvm::LogicalResult hlfir::EvaluateInMemoryOp::verify() {
2026+
unsigned shapeRank = 0;
2027+
if (mlir::Value shape = getShape())
2028+
if (auto shapeTy = mlir::dyn_cast<fir::ShapeType>(shape.getType()))
2029+
shapeRank = shapeTy.getRank();
2030+
auto exprType = mlir::cast<hlfir::ExprType>(getResult().getType());
2031+
if (shapeRank != exprType.getRank())
2032+
return emitOpError("`shape` rank must match the result rank");
2033+
mlir::Type elementType = exprType.getElementType();
2034+
if (auto res = verifyTypeparams(*this, elementType, getTypeparams().size());
2035+
failed(res))
2036+
return res;
2037+
return mlir::success();
2038+
}
2039+
19922040
#include "flang/Optimizer/HLFIR/HLFIROpInterfaces.cpp.inc"
19932041
#define GET_OP_CLASSES
19942042
#include "flang/Optimizer/HLFIR/HLFIREnums.cpp.inc"

flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -905,6 +905,26 @@ struct CharExtremumOpConversion
905905
}
906906
};
907907

908+
struct EvaluateInMemoryOpConversion
909+
: public mlir::OpConversionPattern<hlfir::EvaluateInMemoryOp> {
910+
using mlir::OpConversionPattern<
911+
hlfir::EvaluateInMemoryOp>::OpConversionPattern;
912+
explicit EvaluateInMemoryOpConversion(mlir::MLIRContext *ctx)
913+
: mlir::OpConversionPattern<hlfir::EvaluateInMemoryOp>{ctx} {}
914+
llvm::LogicalResult
915+
matchAndRewrite(hlfir::EvaluateInMemoryOp evalInMemOp, OpAdaptor adaptor,
916+
mlir::ConversionPatternRewriter &rewriter) const override {
917+
mlir::Location loc = evalInMemOp->getLoc();
918+
fir::FirOpBuilder builder(rewriter, evalInMemOp.getOperation());
919+
auto [temp, isHeapAlloc] = hlfir::computeEvaluateOpInNewTemp(
920+
loc, builder, evalInMemOp, adaptor.getShape(), adaptor.getTypeparams());
921+
mlir::Value bufferizedExpr =
922+
packageBufferizedExpr(loc, builder, temp, isHeapAlloc);
923+
rewriter.replaceOp(evalInMemOp, bufferizedExpr);
924+
return mlir::success();
925+
}
926+
};
927+
908928
class BufferizeHLFIR : public hlfir::impl::BufferizeHLFIRBase<BufferizeHLFIR> {
909929
public:
910930
void runOnOperation() override {
@@ -918,12 +938,13 @@ class BufferizeHLFIR : public hlfir::impl::BufferizeHLFIRBase<BufferizeHLFIR> {
918938
auto module = this->getOperation();
919939
auto *context = &getContext();
920940
mlir::RewritePatternSet patterns(context);
921-
patterns.insert<ApplyOpConversion, AsExprOpConversion, AssignOpConversion,
922-
AssociateOpConversion, CharExtremumOpConversion,
923-
ConcatOpConversion, DestroyOpConversion,
924-
ElementalOpConversion, EndAssociateOpConversion,
925-
NoReassocOpConversion, SetLengthOpConversion,
926-
ShapeOfOpConversion, GetLengthOpConversion>(context);
941+
patterns
942+
.insert<ApplyOpConversion, AsExprOpConversion, AssignOpConversion,
943+
AssociateOpConversion, CharExtremumOpConversion,
944+
ConcatOpConversion, DestroyOpConversion, ElementalOpConversion,
945+
EndAssociateOpConversion, EvaluateInMemoryOpConversion,
946+
NoReassocOpConversion, SetLengthOpConversion,
947+
ShapeOfOpConversion, GetLengthOpConversion>(context);
927948
mlir::ConversionTarget target(*context);
928949
// Note that YieldElementOp is not marked as an illegal operation.
929950
// It must be erased by its parent converter and there is no explicit

0 commit comments

Comments
 (0)