Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flang/include/flang/Optimizer/Builder/FIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
mlir::Block *getAllocaBlock();

/// Safely create a reference type to the type `eleTy`.
mlir::Type getRefType(mlir::Type eleTy);
mlir::Type getRefType(mlir::Type eleTy, bool isVolatile = false);

/// Create a sequence of `eleTy` with `rank` dimensions of unknown size.
mlir::Type getVarLenSeqTy(mlir::Type eleTy, unsigned rank = 1);
Expand Down
6 changes: 6 additions & 0 deletions flang/include/flang/Optimizer/Dialect/FIRType.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,12 @@ inline bool isa_ref_type(mlir::Type t) {
fir::LLVMPointerType>(t);
}

inline bool isa_volatile_ref_type(mlir::Type t) {
if (auto refTy = mlir::dyn_cast_or_null<fir::ReferenceType>(t))
return refTy.isVolatile();
return false;
}

/// Is `t` a boxed type?
inline bool isa_box_type(mlir::Type t) {
return mlir::isa<fir::BaseBoxType, fir::BoxCharType, fir::BoxProcType>(t);
Expand Down
11 changes: 8 additions & 3 deletions flang/include/flang/Optimizer/Dialect/FIRTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#define FIR_DIALECT_FIR_TYPES

include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinAttributes.td"
include "flang/Optimizer/Dialect/FIRDialect.td"

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -363,18 +364,22 @@ def fir_ReferenceType : FIR_Type<"Reference", "ref"> {
The type of a reference to an entity in memory.
}];

let parameters = (ins "mlir::Type":$eleTy);
let parameters = (ins
"mlir::Type":$eleTy,
DefaultValuedParameter<"bool", "false">:$isVol);

let skipDefaultBuilders = 1;

let builders = [
TypeBuilderWithInferredContext<(ins "mlir::Type":$elementType), [{
return Base::get(elementType.getContext(), elementType);
TypeBuilderWithInferredContext<(ins "mlir::Type":$elementType, CArg<"bool", "false">:$isVol), [{
return Base::get(elementType.getContext(), elementType, isVol);
}]>,
];

let extraClassDeclaration = [{
mlir::Type getElementType() const { return getEleTy(); }
bool isVolatile() const { return (bool)getIsVol(); }
static llvm::StringRef getVolatileKeyword() { return "volatile"; }
}];

let genVerifyDecl = 1;
Expand Down
1 change: 0 additions & 1 deletion flang/lib/Lower/CallInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1112,7 +1112,6 @@ class Fortran::lower::CallInterfaceImpl {
if (obj.attrs.test(Attrs::Value))
isValueAttr = true; // TODO: do we want an mlir::Attribute as well?
if (obj.attrs.test(Attrs::Volatile)) {
TODO(loc, "VOLATILE in procedure interface");
addMLIRAttr(fir::getVolatileAttrName());
}
// obj.attrs.test(Attrs::Asynchronous) does not impact the way the argument
Expand Down
44 changes: 38 additions & 6 deletions flang/lib/Lower/ConvertExprToHLFIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,36 @@ class HlfirDesignatorBuilder {
designatorNode, getConverter().getFoldingContext(),
/*namedConstantSectionsAreAlwaysContiguous=*/false))
return fir::BoxType::get(resultValueType);

bool isVolatile = false;

// Check if the base type is volatile
if (partInfo.base.has_value()) {
mlir::Type baseType = partInfo.base.value().getType();
isVolatile = fir::isa_volatile_ref_type(baseType);
}

auto isVolatileSymbol = [&](const Fortran::semantics::Symbol &symbol) {
return symbol.GetUltimate().attrs().test(Fortran::semantics::Attr::VOLATILE);
};

// Check if this should be a volatile reference
if constexpr (std::is_same_v<std::decay_t<T>,
Fortran::evaluate::SymbolRef>) {
if (isVolatileSymbol(designatorNode.get()))
isVolatile = true;
} else if constexpr (std::is_same_v<std::decay_t<T>,
Fortran::evaluate::Component>) {
if (isVolatileSymbol(designatorNode.GetLastSymbol()))
isVolatile = true;
}

// If it's a reference to a ref, account for it
if (auto refTy = mlir::dyn_cast<fir::ReferenceType>(resultValueType))
resultValueType = refTy.getEleTy();

// Other designators can be handled as raw addresses.
return fir::ReferenceType::get(resultValueType);
return fir::ReferenceType::get(resultValueType, isVolatile);
}

template <typename T>
Expand Down Expand Up @@ -414,10 +442,13 @@ class HlfirDesignatorBuilder {
.Case<fir::SequenceType>([&](fir::SequenceType seqTy) -> mlir::Type {
return fir::SequenceType::get(seqTy.getShape(), newEleTy);
})
.Case<fir::PointerType, fir::HeapType, fir::ReferenceType, fir::BoxType,
fir::ClassType>([&](auto t) -> mlir::Type {
using FIRT = decltype(t);
return FIRT::get(changeElementType(t.getEleTy(), newEleTy));
.Case<fir::PointerType, fir::HeapType, fir::BoxType, fir::ClassType>(
[&](auto t) -> mlir::Type {
using FIRT = decltype(t);
return FIRT::get(changeElementType(t.getEleTy(), newEleTy));
})
.Case<fir::ReferenceType>([&](fir::ReferenceType refTy) -> mlir::Type {
return fir::ReferenceType::get(changeElementType(refTy.getEleTy(), newEleTy), refTy.isVolatile());
})
.Default([newEleTy](mlir::Type t) -> mlir::Type { return newEleTy; });
}
Expand Down Expand Up @@ -1808,6 +1839,7 @@ class HlfirBuilder {
auto &expr = std::get<const Fortran::lower::SomeExpr &>(iter);
auto &baseOp = std::get<hlfir::EntityWithAttributes>(iter);
std::string name = converter.getRecordTypeFieldName(sym);
const bool isVolatile = fir::isa_volatile_ref_type(baseOp.getType());

// Generate DesignateOp for the component.
// The designator's result type is just a reference to the component type,
Expand All @@ -1818,7 +1850,7 @@ class HlfirBuilder {
assert(compType && "failed to retrieve component type");
mlir::Value compShape =
designatorBuilder.genComponentShape(sym, compType);
mlir::Type designatorType = builder.getRefType(compType);
mlir::Type designatorType = builder.getRefType(compType, isVolatile);

mlir::Type fieldElemType = hlfir::getFortranElementType(compType);
llvm::SmallVector<mlir::Value, 1> typeParams;
Expand Down
4 changes: 2 additions & 2 deletions flang/lib/Optimizer/Builder/FIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,9 @@ fir::FirOpBuilder::getNamedGlobal(mlir::ModuleOp modOp,
return modOp.lookupSymbol<fir::GlobalOp>(name);
}

mlir::Type fir::FirOpBuilder::getRefType(mlir::Type eleTy) {
mlir::Type fir::FirOpBuilder::getRefType(mlir::Type eleTy, bool isVolatile) {
assert(!mlir::isa<fir::ReferenceType>(eleTy) && "cannot be a reference type");
return fir::ReferenceType::get(eleTy);
return fir::ReferenceType::get(eleTy, isVolatile);
}

mlir::Type fir::FirOpBuilder::getVarLenSeqTy(mlir::Type eleTy, unsigned rank) {
Expand Down
3 changes: 2 additions & 1 deletion flang/lib/Optimizer/Builder/HLFIRTools.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -809,7 +809,8 @@ mlir::Type hlfir::getVariableElementType(hlfir::Entity variable) {
} else if (fir::isRecordWithTypeParameters(eleTy)) {
return fir::BoxType::get(eleTy);
}
return fir::ReferenceType::get(eleTy);
const bool isVolatile = fir::isa_volatile_ref_type(variable.getType());
return fir::ReferenceType::get(eleTy, isVolatile);
}

mlir::Type hlfir::getEntityElementType(hlfir::Entity entity) {
Expand Down
6 changes: 4 additions & 2 deletions flang/lib/Optimizer/CodeGen/CodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3218,6 +3218,7 @@ struct LoadOpConversion : public fir::FIROpConversion<fir::LoadOp> {
mlir::ConversionPatternRewriter &rewriter) const override {

mlir::Type llvmLoadTy = convertObjectType(load.getType());
const bool isVolatile = fir::isa_volatile_ref_type(load.getMemref().getType());
if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(load.getType())) {
// fir.box is a special case because it is considered an ssa value in
// fir, but it is lowered as a pointer to a descriptor. So
Expand Down Expand Up @@ -3247,16 +3248,17 @@ struct LoadOpConversion : public fir::FIROpConversion<fir::LoadOp> {
mlir::Value boxSize =
computeBoxSize(loc, boxTypePair, inputBoxStorage, rewriter);
auto memcpy = rewriter.create<mlir::LLVM::MemcpyOp>(
loc, newBoxStorage, inputBoxStorage, boxSize, /*isVolatile=*/false);
loc, newBoxStorage, inputBoxStorage, boxSize, isVolatile);

if (std::optional<mlir::ArrayAttr> optionalTag = load.getTbaa())
memcpy.setTBAATags(*optionalTag);
else
attachTBAATag(memcpy, boxTy, boxTy, nullptr);
rewriter.replaceOp(load, newBoxStorage);
} else {
auto memref = adaptor.getOperands()[0];
auto loadOp = rewriter.create<mlir::LLVM::LoadOp>(
load.getLoc(), llvmLoadTy, adaptor.getOperands(), load->getAttrs());
load.getLoc(), llvmLoadTy, memref, /*alignment=*/0, isVolatile);
if (std::optional<mlir::ArrayAttr> optionalTag = load.getTbaa())
loadOp.setTBAATags(*optionalTag);
else
Expand Down
29 changes: 25 additions & 4 deletions flang/lib/Optimizer/Dialect/FIRType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1057,18 +1057,39 @@ unsigned fir::RecordType::getFieldIndex(llvm::StringRef ident) {
// ReferenceType
//===----------------------------------------------------------------------===//

// `ref` `<` type `>`
// `ref` `<` type (`, volatile` $volatile^)? (`, async` $async^)? `>`
mlir::Type fir::ReferenceType::parse(mlir::AsmParser &parser) {
return parseTypeSingleton<fir::ReferenceType>(parser);
if (parser.parseLess())
return {};

mlir::Type eleTy;
if (parser.parseType(eleTy))
return {};

bool isVolatile = false;
if (parser.parseOptionalComma()) {
if (parser.parseOptionalKeyword(getVolatileKeyword())) {
isVolatile = true;
} else {
return {};
}
}

if (parser.parseGreater())
return {};
return ReferenceType::get(eleTy, isVolatile);
}

void fir::ReferenceType::print(mlir::AsmPrinter &printer) const {
printer << "<" << getEleTy() << '>';
printer << "<" << getEleTy();
if (isVolatile())
printer << ", volatile";
printer << '>';
}

llvm::LogicalResult fir::ReferenceType::verify(
llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
mlir::Type eleTy) {
mlir::Type eleTy, bool isVolatile) {
if (mlir::isa<ShapeType, ShapeShiftType, SliceType, FieldType, LenType,
ReferenceType, TypeDescType>(eleTy))
return emitError() << "cannot build a reference to type: " << eleTy << '\n';
Expand Down
7 changes: 7 additions & 0 deletions flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,13 @@ void hlfir::DeclareOp::build(mlir::OpBuilder &builder,
auto nameAttr = builder.getStringAttr(uniq_name);
mlir::Type inputType = memref.getType();
bool hasExplicitLbs = hasExplicitLowerBounds(shape);
if (fortran_attrs && mlir::isa<fir::ReferenceType>(inputType) &&
bitEnumContainsAny(fortran_attrs.getFlags(),
fir::FortranVariableFlagsEnum::fortran_volatile)) {
auto refType = mlir::cast<fir::ReferenceType>(inputType);
inputType = fir::ReferenceType::get(refType.getEleTy(), true);
memref = builder.create<fir::ConvertOp>(memref.getLoc(), inputType, memref);
}
mlir::Type hlfirVariableType =
getHLFIRVariableType(inputType, hasExplicitLbs);
build(builder, result, {hlfirVariableType, inputType}, memref, shape,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1126,7 +1126,7 @@ class ReductionMaskConversion : public mlir::OpRewritePattern<Op> {
builder.create<fir::StoreOp>(loc, flagSet, flagRef);
mlir::Type resultElemTy =
hlfir::getFortranElementType(resultArr.getType());
mlir::Type returnRefTy = builder.getRefType(resultElemTy);
mlir::Type returnRefTy = builder.getRefType(resultElemTy, fir::isa_volatile_ref_type(flagRef.getType()));
mlir::IndexType idxTy = builder.getIndexType();

for (unsigned int i = 0; i < rank; ++i) {
Expand All @@ -1153,7 +1153,7 @@ class ReductionMaskConversion : public mlir::OpRewritePattern<Op> {
auto getAddrFn = [](fir::FirOpBuilder builder, mlir::Location loc,
const mlir::Type &resultElemType, mlir::Value resultArr,
mlir::Value index) {
mlir::Type resultRefTy = builder.getRefType(resultElemType);
mlir::Type resultRefTy = builder.getRefType(resultElemType, fir::isa_volatile_ref_type(resultArr.getType()));
mlir::Value oneIdx =
builder.createIntegerConstant(loc, builder.getIndexType(), 1);
index = builder.create<mlir::arith::AddIOp>(loc, index, oneIdx);
Expand All @@ -1162,8 +1162,9 @@ class ReductionMaskConversion : public mlir::OpRewritePattern<Op> {
};

// Initialize the result
const bool isVolatile = fir::isa_volatile_ref_type(resultArr.getType());
mlir::Type resultElemTy = hlfir::getFortranElementType(resultArr.getType());
mlir::Type resultRefTy = builder.getRefType(resultElemTy);
mlir::Type resultRefTy = builder.getRefType(resultElemTy, isVolatile);
mlir::Value returnValue =
builder.createIntegerConstant(loc, resultElemTy, 0);
for (unsigned int i = 0; i < rank; ++i) {
Expand Down
Loading