Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,10 @@ class ConvertFIRToLLVMPattern : public mlir::ConvertToLLVMPattern {
this->getTypeConverter());
}

/// Return the data layout exposed by the type converter that lowerTy()
/// provides for this pattern.
const mlir::DataLayout &getDataLayout() const {
  const auto &typeConverter = lowerTy();
  return typeConverter.getDataLayout();
}

void attachTBAATag(mlir::LLVM::AliasAnalysisOpInterface op,
mlir::Type baseFIRType, mlir::Type accessFIRType,
mlir::LLVM::GEPOp gep) const {
Expand Down
34 changes: 34 additions & 0 deletions flang/include/flang/Optimizer/CodeGen/LLVMInsertChainFolder.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
//===-- LLVMInsertChainFolder.h -- insertvalue chain folder ----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Helper to fold LLVM dialect llvm.insertvalue chain representing constants
// into an Attribute representation.
// This sits in Flang because it is incomplete and tailored for flang needs.
//
//===----------------------------------------------------------------------===//

#ifndef FORTRAN_OPTIMIZER_CODEGEN_LLVMINSERTCHAINFOLDER_H
#define FORTRAN_OPTIMIZER_CODEGEN_LLVMINSERTCHAINFOLDER_H

#include "llvm/Support/LogicalResult.h"

// Forward declarations are enough here: the MLIR types are only named in the
// function signature below.
namespace mlir {
class Attribute;
class OpBuilder;
class Value;
} // namespace mlir

namespace fir {

/// Attempt to fold an llvm.insertvalue chain into an attribute representation
/// suitable as llvm.constant operand.
///
/// \param insertChainResult value expected to be the result of the last
///        llvm.insertvalue of the chain.
/// \param builder builder that may be used to fold some llvm.insertvalue
///        value operands on the way.
/// \return the folded attribute, or failure if \p insertChainResult is not an
///         llvm.insertvalue result, if the chain is not a constant, or if it
///         cannot be represented as an Attribute.
///
/// The operations of the chain are not deleted.
llvm::FailureOr<mlir::Attribute>
tryFoldingLLVMInsertChain(mlir::Value insertChainResult,
                          mlir::OpBuilder &builder);
} // namespace fir

#endif // FORTRAN_OPTIMIZER_CODEGEN_LLVMINSERTCHAINFOLDER_H
5 changes: 5 additions & 0 deletions flang/include/flang/Optimizer/Dialect/FIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2129,6 +2129,11 @@ def fir_InsertOnRangeOp : fir_OneResultOp<"insert_on_range", [NoMemoryEffect]> {
$seq `,` $val custom<CustomRangeSubscript>($coor) attr-dict `:` functional-type(operands, results)
}];

let extraClassDeclaration = [{
/// Is this insert_on_range inserting on all the values of the result type?
/// NOTE(review): only the declaration lives here; the definition is expected
/// in the dialect implementation file, which is not visible from this
/// record -- confirm it handles every coordinate pair of the op.
bool isFullRange();
}];

let hasVerifier = 1;
}

Expand Down
1 change: 1 addition & 0 deletions flang/lib/Optimizer/CodeGen/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ add_flang_library(FIRCodeGen
CodeGen.cpp
CodeGenOpenMP.cpp
FIROpPatterns.cpp
LLVMInsertChainFolder.cpp
LowerRepackArrays.cpp
PreCGRewrite.cpp
TBAABuilder.cpp
Expand Down
102 changes: 52 additions & 50 deletions flang/lib/Optimizer/CodeGen/CodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "flang/Optimizer/CodeGen/CodeGenOpenMP.h"
#include "flang/Optimizer/CodeGen/FIROpPatterns.h"
#include "flang/Optimizer/CodeGen/LLVMInsertChainFolder.h"
#include "flang/Optimizer/CodeGen/TypeConverter.h"
#include "flang/Optimizer/Dialect/FIRAttr.h"
#include "flang/Optimizer/Dialect/FIRCG/CGOps.h"
Expand Down Expand Up @@ -1043,33 +1044,23 @@ static mlir::SymbolRefAttr getMalloc(fir::AllocMemOp op,
/// Compute the distance in bytes between two adjacent elements of LLVM type
/// \p llvmObjectType, i.e. the size of the type reported by \p dataLayout
/// rounded up to its ABI alignment. The byte count is handed to
/// genConstantIndex() to produce the returned value of type \p idxTy.
static mlir::Value
computeElementDistance(mlir::Location loc, mlir::Type llvmObjectType,
                       mlir::Type idxTy,
                       mlir::ConversionPatternRewriter &rewriter,
                       const mlir::DataLayout &dataLayout) {
  llvm::TypeSize size = dataLayout.getTypeSize(llvmObjectType);
  // Keep the alignment at full width: storing it in a 16-bit integer would
  // silently truncate large alignment values before the rounding below.
  std::uint64_t alignment = dataLayout.getTypeABIAlignment(llvmObjectType);
  // Adjacent elements are separated by the size padded to the alignment, not
  // by the raw size (the two differ for types with tail padding).
  std::int64_t distance = llvm::alignTo(size, alignment);
  return genConstantIndex(loc, idxTy, rewriter, distance);
}

/// Return value of the stride in bytes between adjacent elements
/// of LLVM type \p llTy. The result is returned as a value of
/// \p idxTy integer type. \p dataLayout is forwarded to
/// computeElementDistance(), which performs the actual computation.
static mlir::Value
genTypeStrideInBytes(mlir::Location loc, mlir::Type idxTy,
                     mlir::ConversionPatternRewriter &rewriter, mlir::Type llTy,
                     const mlir::DataLayout &dataLayout) {
  // The stride is the element distance; note that computeElementDistance()
  // takes the object type before the index type.
  return computeElementDistance(loc, llTy, idxTy, rewriter, dataLayout);
}

namespace {
Expand Down Expand Up @@ -1111,7 +1102,7 @@ struct AllocMemOpConversion : public fir::FIROpConversion<fir::AllocMemOp> {
/// Return the size in bytes of an object of LLVM type \p llTy as a value of
/// type \p idxTy. The size is the element distance computed by
/// computeElementDistance() with the data layout of this pattern.
mlir::Value genTypeSizeInBytes(mlir::Location loc, mlir::Type idxTy,
                               mlir::ConversionPatternRewriter &rewriter,
                               mlir::Type llTy) const {
  return computeElementDistance(loc, llTy, idxTy, rewriter, getDataLayout());
}
};
} // namespace
Expand Down Expand Up @@ -1323,8 +1314,8 @@ struct EmboxCommonConversion : public fir::FIROpConversion<OP> {
fir::CharacterType charTy,
mlir::ValueRange lenParams) const {
auto i64Ty = mlir::IntegerType::get(rewriter.getContext(), 64);
mlir::Value size =
genTypeStrideInBytes(loc, i64Ty, rewriter, this->convertType(charTy));
mlir::Value size = genTypeStrideInBytes(
loc, i64Ty, rewriter, this->convertType(charTy), this->getDataLayout());
if (charTy.hasConstantLen())
return size; // Length accounted for in the genTypeStrideInBytes GEP.
// Otherwise, multiply the single character size by the length.
Expand All @@ -1338,6 +1329,7 @@ struct EmboxCommonConversion : public fir::FIROpConversion<OP> {
std::tuple<mlir::Value, mlir::Value> getSizeAndTypeCode(
mlir::Location loc, mlir::ConversionPatternRewriter &rewriter,
mlir::Type boxEleTy, mlir::ValueRange lenParams = {}) const {
const mlir::DataLayout &dataLayout = this->getDataLayout();
auto i64Ty = mlir::IntegerType::get(rewriter.getContext(), 64);
if (auto eleTy = fir::dyn_cast_ptrEleTy(boxEleTy))
boxEleTy = eleTy;
Expand All @@ -1354,18 +1346,19 @@ struct EmboxCommonConversion : public fir::FIROpConversion<OP> {
mlir::dyn_cast<fir::LogicalType>(boxEleTy) || fir::isa_real(boxEleTy) ||
fir::isa_complex(boxEleTy))
return {genTypeStrideInBytes(loc, i64Ty, rewriter,
this->convertType(boxEleTy)),
this->convertType(boxEleTy), dataLayout),
typeCodeVal};
if (auto charTy = mlir::dyn_cast<fir::CharacterType>(boxEleTy))
return {getCharacterByteSize(loc, rewriter, charTy, lenParams),
typeCodeVal};
if (fir::isa_ref_type(boxEleTy)) {
auto ptrTy = ::getLlvmPtrType(rewriter.getContext());
return {genTypeStrideInBytes(loc, i64Ty, rewriter, ptrTy), typeCodeVal};
return {genTypeStrideInBytes(loc, i64Ty, rewriter, ptrTy, dataLayout),
typeCodeVal};
}
if (mlir::isa<fir::RecordType>(boxEleTy))
return {genTypeStrideInBytes(loc, i64Ty, rewriter,
this->convertType(boxEleTy)),
this->convertType(boxEleTy), dataLayout),
typeCodeVal};
fir::emitFatalError(loc, "unhandled type in fir.box code generation");
}
Expand Down Expand Up @@ -1909,8 +1902,8 @@ struct XEmboxOpConversion : public EmboxCommonConversion<fir::cg::XEmboxOp> {
if (hasSubcomp) {
// We have a subcomponent. The step value needs to be the number of
// bytes per element (which is a derived type).
prevDimByteStride =
genTypeStrideInBytes(loc, i64Ty, rewriter, convertType(seqEleTy));
prevDimByteStride = genTypeStrideInBytes(
loc, i64Ty, rewriter, convertType(seqEleTy), getDataLayout());
} else if (hasSubstr) {
// We have a substring. The step value needs to be the number of bytes
// per CHARACTER element.
Expand Down Expand Up @@ -2420,15 +2413,39 @@ struct InsertOnRangeOpConversion
doRewrite(fir::InsertOnRangeOp range, mlir::Type ty, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {

llvm::SmallVector<std::int64_t> dims;
auto type = adaptor.getOperands()[0].getType();
auto arrayType = adaptor.getSeq().getType();

// Iteratively extract the array dimensions from the type.
llvm::SmallVector<std::int64_t> dims;
mlir::Type type = arrayType;
while (auto t = mlir::dyn_cast<mlir::LLVM::LLVMArrayType>(type)) {
dims.push_back(t.getNumElements());
type = t.getElementType();
}

// Avoid generating long insert chains, which are very slow to fold back
// (folding is required in globals when later generating LLVM IR). Attempt
// to fold the inserted element value to an attribute and build an ArrayAttr
// for the resulting array.
if (range.isFullRange()) {
llvm::FailureOr<mlir::Attribute> cst =
fir::tryFoldingLLVMInsertChain(adaptor.getVal(), rewriter);
if (llvm::succeeded(cst)) {
mlir::Attribute dimVal = *cst;
for (auto dim : llvm::reverse(dims)) {
// Use std::vector in case the number of elements is big.
std::vector<mlir::Attribute> elements(dim, dimVal);
dimVal = mlir::ArrayAttr::get(range.getContext(), elements);
}
// Replace insert chain with constant.
rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>(range, arrayType,
dimVal);
return mlir::success();
}
}

// The range is partial or the inserted value cannot be folded to an
// attribute; turn the insert_on_range into an llvm.insertvalue chain.
llvm::SmallVector<std::int64_t> lBounds;
llvm::SmallVector<std::int64_t> uBounds;

Expand All @@ -2442,8 +2459,8 @@ struct InsertOnRangeOpConversion

auto &subscripts = lBounds;
auto loc = range.getLoc();
mlir::Value lastOp = adaptor.getOperands()[0];
mlir::Value insertVal = adaptor.getOperands()[1];
mlir::Value lastOp = adaptor.getSeq();
mlir::Value insertVal = adaptor.getVal();

while (subscripts != uBounds) {
lastOp = rewriter.create<mlir::LLVM::InsertValueOp>(
Expand Down Expand Up @@ -3139,7 +3156,7 @@ struct GlobalOpConversion : public fir::FIROpConversion<fir::GlobalOp> {
// initialization is on the full range.
auto insertOnRangeOps = gr.front().getOps<fir::InsertOnRangeOp>();
for (auto insertOp : insertOnRangeOps) {
if (isFullRange(insertOp.getCoor(), insertOp.getType())) {
if (insertOp.isFullRange()) {
auto seqTyAttr = convertType(insertOp.getType());
auto *op = insertOp.getVal().getDefiningOp();
auto constant = mlir::dyn_cast<mlir::arith::ConstantOp>(op);
Expand Down Expand Up @@ -3169,22 +3186,7 @@ struct GlobalOpConversion : public fir::FIROpConversion<fir::GlobalOp> {
return mlir::success();
}

/// Return true when \p indexes designates every element of \p seqTy.
///
/// \p indexes is read as consecutive (lower, upper) pairs, one pair per
/// dimension of the shape of \p seqTy. The range is full when each lower
/// value is 0 and each upper value is the dimension extent minus one.
/// Returns false when the number of values does not form exactly one pair
/// per dimension.
bool isFullRange(mlir::DenseIntElementsAttr indexes,
                 fir::SequenceType seqTy) const {
  auto extents = seqTy.getShape();
  const int64_t numBounds = indexes.size();
  // Reject an odd count explicitly: with integer division alone it would pass
  // the rank check and the pairwise reads below would run past the end.
  if (numBounds % 2 != 0 ||
      numBounds / 2 != static_cast<int64_t>(extents.size()))
    return false;
  auto bound = indexes.value_begin<int64_t>();
  for (auto extent : extents) {
    // Lower bound of the dimension must be the first element.
    if (*(bound++) != 0)
      return false;
    // Upper bound of the dimension must be the last element.
    if (*(bound++) != extent - 1)
      return false;
  }
  return true;
}

// TODO: String comparisons should be avoided. Replace linkName with an
// enumeration.
mlir::LLVM::Linkage
convertLinkage(std::optional<llvm::StringRef> optLinkage) const {
Expand Down Expand Up @@ -3604,8 +3606,8 @@ struct CopyOpConversion : public fir::FIROpConversion<fir::CopyOp> {
mlir::Value llvmDestination = adaptor.getDestination();
mlir::Type i64Ty = mlir::IntegerType::get(rewriter.getContext(), 64);
mlir::Type copyTy = fir::unwrapRefType(copy.getSource().getType());
mlir::Value copySize =
genTypeStrideInBytes(loc, i64Ty, rewriter, convertType(copyTy));
mlir::Value copySize = genTypeStrideInBytes(
loc, i64Ty, rewriter, convertType(copyTy), getDataLayout());

mlir::LLVM::AliasAnalysisOpInterface newOp;
if (copy.getNoOverlap())
Expand Down
Loading