Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,10 @@ class ConvertFIRToLLVMPattern : public mlir::ConvertToLLVMPattern {
this->getTypeConverter());
}

const mlir::DataLayout &getDataLayout() const {
return lowerTy().getDataLayout();
}

void attachTBAATag(mlir::LLVM::AliasAnalysisOpInterface op,
mlir::Type baseFIRType, mlir::Type accessFIRType,
mlir::LLVM::GEPOp gep) const {
Expand Down
34 changes: 34 additions & 0 deletions flang/include/flang/Optimizer/CodeGen/LLVMInsertChainFolder.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
//===-- LLVMInsertChainFolder.h -- insertvalue chain folder ----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Helper to fold LLVM dialect llvm.insertvalue chain representing constants
// into an Attribute representation.
// This sits in Flang because it is incomplete and tailored for flang needs.
//
//===----------------------------------------------------------------------===//

#include "llvm/Support/LogicalResult.h"

namespace mlir {
class Attribute;
class OpBuilder;
class Value;
} // namespace mlir

namespace fir {

/// Attempt to fold an llvm.insertvalue chain into an attribute representation
/// suitable as llvm.constant operand. The returned value will be a null pointer
/// if this is not an llvm.insertvalue result pr if the chain is not a constant,
/// or cannot be represented as an Attribute. The operations are not deleted,
/// but some llvm.insertvalue value operands may be folded with the builder on
/// the way.
llvm::FailureOr<mlir::Attribute>
tryFoldingLLVMInsertChain(mlir::Value insertChainResult,
mlir::OpBuilder &builder);
} // namespace fir
5 changes: 5 additions & 0 deletions flang/include/flang/Optimizer/Dialect/FIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2129,6 +2129,11 @@ def fir_InsertOnRangeOp : fir_OneResultOp<"insert_on_range", [NoMemoryEffect]> {
$seq `,` $val custom<CustomRangeSubscript>($coor) attr-dict `:` functional-type(operands, results)
}];

let extraClassDeclaration = [{
/// Is this insert_on_range inserting on all the values of the result type?
bool isFullRange();
}];

let hasVerifier = 1;
}

Expand Down
1 change: 1 addition & 0 deletions flang/lib/Optimizer/CodeGen/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ add_flang_library(FIRCodeGen
CodeGen.cpp
CodeGenOpenMP.cpp
FIROpPatterns.cpp
LLVMInsertChainFolder.cpp
LowerRepackArrays.cpp
PreCGRewrite.cpp
TBAABuilder.cpp
Expand Down
102 changes: 52 additions & 50 deletions flang/lib/Optimizer/CodeGen/CodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "flang/Optimizer/CodeGen/CodeGenOpenMP.h"
#include "flang/Optimizer/CodeGen/FIROpPatterns.h"
#include "flang/Optimizer/CodeGen/LLVMInsertChainFolder.h"
#include "flang/Optimizer/CodeGen/TypeConverter.h"
#include "flang/Optimizer/Dialect/FIRAttr.h"
#include "flang/Optimizer/Dialect/FIRCG/CGOps.h"
Expand Down Expand Up @@ -1043,33 +1044,23 @@ static mlir::SymbolRefAttr getMalloc(fir::AllocMemOp op,
static mlir::Value
computeElementDistance(mlir::Location loc, mlir::Type llvmObjectType,
mlir::Type idxTy,
mlir::ConversionPatternRewriter &rewriter) {
// Note that we cannot use something like
// mlir::LLVM::getPrimitiveTypeSizeInBits() for the element type here. For
// example, it returns 10 bytes for mlir::Float80Type for targets where it
// occupies 16 bytes. Proper solution is probably to use
// mlir::DataLayout::getTypeABIAlignment(), but DataLayout is not being set
// yet (see llvm-project#57230). For the time being use the '(intptr_t)((type
// *)0 + 1)' trick for all types. The generated instructions are optimized
// into constant by the first pass of InstCombine, so it should not be a
// performance issue.
auto llvmPtrTy = ::getLlvmPtrType(llvmObjectType.getContext());
auto nullPtr = rewriter.create<mlir::LLVM::ZeroOp>(loc, llvmPtrTy);
auto gep = rewriter.create<mlir::LLVM::GEPOp>(
loc, llvmPtrTy, llvmObjectType, nullPtr,
llvm::ArrayRef<mlir::LLVM::GEPArg>{1});
return rewriter.create<mlir::LLVM::PtrToIntOp>(loc, idxTy, gep);
mlir::ConversionPatternRewriter &rewriter,
const mlir::DataLayout &dataLayout) {
llvm::TypeSize size = dataLayout.getTypeSize(llvmObjectType);
unsigned short alignment = dataLayout.getTypeABIAlignment(llvmObjectType);
std::int64_t distance = llvm::alignTo(size, alignment);
return genConstantIndex(loc, idxTy, rewriter, distance);
}

/// Return value of the stride in bytes between adjacent elements
/// of LLVM type \p llTy. The result is returned as a value of
/// \p idxTy integer type.
static mlir::Value
genTypeStrideInBytes(mlir::Location loc, mlir::Type idxTy,
mlir::ConversionPatternRewriter &rewriter,
mlir::Type llTy) {
mlir::ConversionPatternRewriter &rewriter, mlir::Type llTy,
const mlir::DataLayout &dataLayout) {
// Create a pointer type and use computeElementDistance().
return computeElementDistance(loc, llTy, idxTy, rewriter);
return computeElementDistance(loc, llTy, idxTy, rewriter, dataLayout);
}

namespace {
Expand Down Expand Up @@ -1111,7 +1102,7 @@ struct AllocMemOpConversion : public fir::FIROpConversion<fir::AllocMemOp> {
mlir::Value genTypeSizeInBytes(mlir::Location loc, mlir::Type idxTy,
mlir::ConversionPatternRewriter &rewriter,
mlir::Type llTy) const {
return computeElementDistance(loc, llTy, idxTy, rewriter);
return computeElementDistance(loc, llTy, idxTy, rewriter, getDataLayout());
}
};
} // namespace
Expand Down Expand Up @@ -1323,8 +1314,8 @@ struct EmboxCommonConversion : public fir::FIROpConversion<OP> {
fir::CharacterType charTy,
mlir::ValueRange lenParams) const {
auto i64Ty = mlir::IntegerType::get(rewriter.getContext(), 64);
mlir::Value size =
genTypeStrideInBytes(loc, i64Ty, rewriter, this->convertType(charTy));
mlir::Value size = genTypeStrideInBytes(
loc, i64Ty, rewriter, this->convertType(charTy), this->getDataLayout());
if (charTy.hasConstantLen())
return size; // Length accounted for in the genTypeStrideInBytes GEP.
// Otherwise, multiply the single character size by the length.
Expand All @@ -1338,6 +1329,7 @@ struct EmboxCommonConversion : public fir::FIROpConversion<OP> {
std::tuple<mlir::Value, mlir::Value> getSizeAndTypeCode(
mlir::Location loc, mlir::ConversionPatternRewriter &rewriter,
mlir::Type boxEleTy, mlir::ValueRange lenParams = {}) const {
const mlir::DataLayout &dataLayout = this->getDataLayout();
auto i64Ty = mlir::IntegerType::get(rewriter.getContext(), 64);
if (auto eleTy = fir::dyn_cast_ptrEleTy(boxEleTy))
boxEleTy = eleTy;
Expand All @@ -1354,18 +1346,19 @@ struct EmboxCommonConversion : public fir::FIROpConversion<OP> {
mlir::dyn_cast<fir::LogicalType>(boxEleTy) || fir::isa_real(boxEleTy) ||
fir::isa_complex(boxEleTy))
return {genTypeStrideInBytes(loc, i64Ty, rewriter,
this->convertType(boxEleTy)),
this->convertType(boxEleTy), dataLayout),
typeCodeVal};
if (auto charTy = mlir::dyn_cast<fir::CharacterType>(boxEleTy))
return {getCharacterByteSize(loc, rewriter, charTy, lenParams),
typeCodeVal};
if (fir::isa_ref_type(boxEleTy)) {
auto ptrTy = ::getLlvmPtrType(rewriter.getContext());
return {genTypeStrideInBytes(loc, i64Ty, rewriter, ptrTy), typeCodeVal};
return {genTypeStrideInBytes(loc, i64Ty, rewriter, ptrTy, dataLayout),
typeCodeVal};
}
if (mlir::isa<fir::RecordType>(boxEleTy))
return {genTypeStrideInBytes(loc, i64Ty, rewriter,
this->convertType(boxEleTy)),
this->convertType(boxEleTy), dataLayout),
typeCodeVal};
fir::emitFatalError(loc, "unhandled type in fir.box code generation");
}
Expand Down Expand Up @@ -1909,8 +1902,8 @@ struct XEmboxOpConversion : public EmboxCommonConversion<fir::cg::XEmboxOp> {
if (hasSubcomp) {
// We have a subcomponent. The step value needs to be the number of
// bytes per element (which is a derived type).
prevDimByteStride =
genTypeStrideInBytes(loc, i64Ty, rewriter, convertType(seqEleTy));
prevDimByteStride = genTypeStrideInBytes(
loc, i64Ty, rewriter, convertType(seqEleTy), getDataLayout());
} else if (hasSubstr) {
// We have a substring. The step value needs to be the number of bytes
// per CHARACTER element.
Expand Down Expand Up @@ -2420,15 +2413,39 @@ struct InsertOnRangeOpConversion
doRewrite(fir::InsertOnRangeOp range, mlir::Type ty, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {

llvm::SmallVector<std::int64_t> dims;
auto type = adaptor.getOperands()[0].getType();
auto arrayType = adaptor.getSeq().getType();

// Iteratively extract the array dimensions from the type.
llvm::SmallVector<std::int64_t> dims;
mlir::Type type = arrayType;
while (auto t = mlir::dyn_cast<mlir::LLVM::LLVMArrayType>(type)) {
dims.push_back(t.getNumElements());
type = t.getElementType();
}

// Avoid generating long insert chain that are very slow to fold back
// (which is required in globals when later generating LLVM IR). Attempt to
// fold the inserted element value to an attribute and build an ArrayAttr
// for the resulting array.
if (range.isFullRange()) {
llvm::FailureOr<mlir::Attribute> cst =
fir::tryFoldingLLVMInsertChain(adaptor.getVal(), rewriter);
if (llvm::succeeded(cst)) {
mlir::Attribute dimVal = *cst;
for (auto dim : llvm::reverse(dims)) {
// Use std::vector in case the number of elements is big.
std::vector<mlir::Attribute> elements(dim, dimVal);
dimVal = mlir::ArrayAttr::get(range.getContext(), elements);
}
// Replace insert chain with constant.
rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>(range, arrayType,
dimVal);
return mlir::success();
}
}

// The inserted value cannot be folded to an attribute, turn the
// insert_range into an llvm.insertvalue chain.
llvm::SmallVector<std::int64_t> lBounds;
llvm::SmallVector<std::int64_t> uBounds;

Expand All @@ -2442,8 +2459,8 @@ struct InsertOnRangeOpConversion

auto &subscripts = lBounds;
auto loc = range.getLoc();
mlir::Value lastOp = adaptor.getOperands()[0];
mlir::Value insertVal = adaptor.getOperands()[1];
mlir::Value lastOp = adaptor.getSeq();
mlir::Value insertVal = adaptor.getVal();

while (subscripts != uBounds) {
lastOp = rewriter.create<mlir::LLVM::InsertValueOp>(
Expand Down Expand Up @@ -3139,7 +3156,7 @@ struct GlobalOpConversion : public fir::FIROpConversion<fir::GlobalOp> {
// initialization is on the full range.
auto insertOnRangeOps = gr.front().getOps<fir::InsertOnRangeOp>();
for (auto insertOp : insertOnRangeOps) {
if (isFullRange(insertOp.getCoor(), insertOp.getType())) {
if (insertOp.isFullRange()) {
auto seqTyAttr = convertType(insertOp.getType());
auto *op = insertOp.getVal().getDefiningOp();
auto constant = mlir::dyn_cast<mlir::arith::ConstantOp>(op);
Expand Down Expand Up @@ -3169,22 +3186,7 @@ struct GlobalOpConversion : public fir::FIROpConversion<fir::GlobalOp> {
return mlir::success();
}

bool isFullRange(mlir::DenseIntElementsAttr indexes,
fir::SequenceType seqTy) const {
auto extents = seqTy.getShape();
if (indexes.size() / 2 != static_cast<int64_t>(extents.size()))
return false;
auto cur_index = indexes.value_begin<int64_t>();
for (unsigned i = 0; i < indexes.size(); i += 2) {
if (*(cur_index++) != 0)
return false;
if (*(cur_index++) != extents[i / 2] - 1)
return false;
}
return true;
}

// TODO: String comparaison should be avoided. Replace linkName with an
// TODO: String comparisons should be avoided. Replace linkName with an
// enumeration.
mlir::LLVM::Linkage
convertLinkage(std::optional<llvm::StringRef> optLinkage) const {
Expand Down Expand Up @@ -3604,8 +3606,8 @@ struct CopyOpConversion : public fir::FIROpConversion<fir::CopyOp> {
mlir::Value llvmDestination = adaptor.getDestination();
mlir::Type i64Ty = mlir::IntegerType::get(rewriter.getContext(), 64);
mlir::Type copyTy = fir::unwrapRefType(copy.getSource().getType());
mlir::Value copySize =
genTypeStrideInBytes(loc, i64Ty, rewriter, convertType(copyTy));
mlir::Value copySize = genTypeStrideInBytes(
loc, i64Ty, rewriter, convertType(copyTy), getDataLayout());

mlir::LLVM::AliasAnalysisOpInterface newOp;
if (copy.getNoOverlap())
Expand Down
Loading