diff --git a/flang/include/flang/Optimizer/Builder/FIRBuilder.h b/flang/include/flang/Optimizer/Builder/FIRBuilder.h index 1675c15363868..d7ddb37480ebb 100644 --- a/flang/include/flang/Optimizer/Builder/FIRBuilder.h +++ b/flang/include/flang/Optimizer/Builder/FIRBuilder.h @@ -150,7 +150,7 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener { mlir::Block *getAllocaBlock(); /// Safely create a reference type to the type `eleTy`. - mlir::Type getRefType(mlir::Type eleTy); + mlir::Type getRefType(mlir::Type eleTy, bool isVolatile = false); /// Create a sequence of `eleTy` with `rank` dimensions of unknown size. mlir::Type getVarLenSeqTy(mlir::Type eleTy, unsigned rank = 1); diff --git a/flang/include/flang/Optimizer/Dialect/FIRType.h b/flang/include/flang/Optimizer/Dialect/FIRType.h index 76e0aa352bcd9..8261c67e4559d 100644 --- a/flang/include/flang/Optimizer/Dialect/FIRType.h +++ b/flang/include/flang/Optimizer/Dialect/FIRType.h @@ -111,6 +111,12 @@ inline bool isa_ref_type(mlir::Type t) { fir::LLVMPointerType>(t); } +inline bool isa_volatile_ref_type(mlir::Type t) { + if (auto refTy = mlir::dyn_cast_or_null(t)) + return refTy.isVolatile(); + return false; +} + /// Is `t` a boxed type? inline bool isa_box_type(mlir::Type t) { return mlir::isa(t); diff --git a/flang/include/flang/Optimizer/Dialect/FIRTypes.td b/flang/include/flang/Optimizer/Dialect/FIRTypes.td index fd5bbbe44751f..c11758cfe9244 100644 --- a/flang/include/flang/Optimizer/Dialect/FIRTypes.td +++ b/flang/include/flang/Optimizer/Dialect/FIRTypes.td @@ -14,6 +14,7 @@ #define FIR_DIALECT_FIR_TYPES include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/BuiltinAttributes.td" include "flang/Optimizer/Dialect/FIRDialect.td" //===----------------------------------------------------------------------===// @@ -363,18 +364,22 @@ def fir_ReferenceType : FIR_Type<"Reference", "ref"> { The type of a reference to an entity in memory. }]; - let parameters = (ins "mlir::Type":$eleTy); + let parameters = (ins + "mlir::Type":$eleTy, + DefaultValuedParameter<"bool", "false">:$isVol); let skipDefaultBuilders = 1; let builders = [ - TypeBuilderWithInferredContext<(ins "mlir::Type":$elementType), [{ - return Base::get(elementType.getContext(), elementType); + TypeBuilderWithInferredContext<(ins "mlir::Type":$elementType, CArg<"bool", "false">:$isVol), [{ + return Base::get(elementType.getContext(), elementType, isVol); }]>, ]; let extraClassDeclaration = [{ mlir::Type getElementType() const { return getEleTy(); } + bool isVolatile() const { return (bool)getIsVol(); } + static llvm::StringRef getVolatileKeyword() { return "volatile"; } }]; let genVerifyDecl = 1; diff --git a/flang/lib/Lower/CallInterface.cpp b/flang/lib/Lower/CallInterface.cpp index 226ba1e52c968..4ee28fbeb9a0c 100644 --- a/flang/lib/Lower/CallInterface.cpp +++ b/flang/lib/Lower/CallInterface.cpp @@ -1112,7 +1112,6 @@ class Fortran::lower::CallInterfaceImpl { if (obj.attrs.test(Attrs::Value)) isValueAttr = true; // TODO: do we want an mlir::Attribute as well? if (obj.attrs.test(Attrs::Volatile)) { - TODO(loc, "VOLATILE in procedure interface"); addMLIRAttr(fir::getVolatileAttrName()); } // obj.attrs.test(Attrs::Asynchronous) does not impact the way the argument diff --git a/flang/lib/Lower/ConvertExprToHLFIR.cpp b/flang/lib/Lower/ConvertExprToHLFIR.cpp index dc00e0b13f583..79906c81ecc68 100644 --- a/flang/lib/Lower/ConvertExprToHLFIR.cpp +++ b/flang/lib/Lower/ConvertExprToHLFIR.cpp @@ -223,8 +223,36 @@ class HlfirDesignatorBuilder { designatorNode, getConverter().getFoldingContext(), /*namedConstantSectionsAreAlwaysContiguous=*/false)) return fir::BoxType::get(resultValueType); + + bool isVolatile = false; + + // Check if the base type is volatile + if (partInfo.base.has_value()) { + mlir::Type baseType = partInfo.base.value().getType(); + isVolatile = fir::isa_volatile_ref_type(baseType); + } + + auto isVolatileSymbol = [&](const Fortran::semantics::Symbol &symbol) { + return symbol.GetUltimate().attrs().test(Fortran::semantics::Attr::VOLATILE); + }; + + // Check if this should be a volatile reference + if constexpr (std::is_same_v, + Fortran::evaluate::SymbolRef>) { + if (isVolatileSymbol(designatorNode.get())) + isVolatile = true; + } else if constexpr (std::is_same_v, + Fortran::evaluate::Component>) { + if (isVolatileSymbol(designatorNode.GetLastSymbol())) + isVolatile = true; + } + + // If it's a reference to a ref, account for it + if (auto refTy = mlir::dyn_cast(resultValueType)) + resultValueType = refTy.getEleTy(); + // Other designators can be handled as raw addresses. - return fir::ReferenceType::get(resultValueType); + return fir::ReferenceType::get(resultValueType, isVolatile); } template @@ -414,10 +442,13 @@ class HlfirDesignatorBuilder { .Case([&](fir::SequenceType seqTy) -> mlir::Type { return fir::SequenceType::get(seqTy.getShape(), newEleTy); }) - .Case([&](auto t) -> mlir::Type { - using FIRT = decltype(t); - return FIRT::get(changeElementType(t.getEleTy(), newEleTy)); + .Case( + [&](auto t) -> mlir::Type { + using FIRT = decltype(t); + return FIRT::get(changeElementType(t.getEleTy(), newEleTy)); + }) + .Case([&](fir::ReferenceType refTy) -> mlir::Type { + return fir::ReferenceType::get(changeElementType(refTy.getEleTy(), newEleTy), refTy.isVolatile()); }) .Default([newEleTy](mlir::Type t) -> mlir::Type { return newEleTy; }); } @@ -1808,6 +1839,7 @@ class HlfirBuilder { auto &expr = std::get(iter); auto &baseOp = std::get(iter); std::string name = converter.getRecordTypeFieldName(sym); + const bool isVolatile = fir::isa_volatile_ref_type(baseOp.getType()); // Generate DesignateOp for the component. // The designator's result type is just a reference to the component type, @@ -1818,7 +1850,7 @@ class HlfirBuilder { assert(compType && "failed to retrieve component type"); mlir::Value compShape = designatorBuilder.genComponentShape(sym, compType); - mlir::Type designatorType = builder.getRefType(compType); + mlir::Type designatorType = builder.getRefType(compType, isVolatile); mlir::Type fieldElemType = hlfir::getFortranElementType(compType); llvm::SmallVector typeParams; diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp index b3d440cedee07..cfae25f8fe4b9 100644 --- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp +++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp @@ -104,9 +104,9 @@ fir::FirOpBuilder::getNamedGlobal(mlir::ModuleOp modOp, return modOp.lookupSymbol(name); } -mlir::Type fir::FirOpBuilder::getRefType(mlir::Type eleTy) { +mlir::Type fir::FirOpBuilder::getRefType(mlir::Type eleTy, bool isVolatile) { assert(!mlir::isa(eleTy) && "cannot be a reference type"); - return fir::ReferenceType::get(eleTy); + return fir::ReferenceType::get(eleTy, isVolatile); } mlir::Type fir::FirOpBuilder::getVarLenSeqTy(mlir::Type eleTy, unsigned rank) { diff --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp index 1a31ca33e9465..cf8bb7eaddf70 100644 --- a/flang/lib/Optimizer/Builder/HLFIRTools.cpp +++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp @@ -809,7 +809,8 @@ mlir::Type hlfir::getVariableElementType(hlfir::Entity variable) { } else if (fir::isRecordWithTypeParameters(eleTy)) { return fir::BoxType::get(eleTy); } - return fir::ReferenceType::get(eleTy); + const bool isVolatile = fir::isa_volatile_ref_type(variable.getType()); + return fir::ReferenceType::get(eleTy, isVolatile); } mlir::Type hlfir::getEntityElementType(hlfir::Entity entity) { diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp index 2cb4cea58c2b0..2ef9fc79403c7 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -3218,6 +3218,7 @@ struct LoadOpConversion : public fir::FIROpConversion { mlir::ConversionPatternRewriter &rewriter) const override { mlir::Type llvmLoadTy = convertObjectType(load.getType()); + const bool isVolatile = fir::isa_volatile_ref_type(load.getMemref().getType()); if (auto boxTy = mlir::dyn_cast(load.getType())) { // fir.box is a special case because it is considered an ssa value in // fir, but it is lowered as a pointer to a descriptor. So @@ -3247,7 +3248,7 @@ struct LoadOpConversion : public fir::FIROpConversion { mlir::Value boxSize = computeBoxSize(loc, boxTypePair, inputBoxStorage, rewriter); auto memcpy = rewriter.create( - loc, newBoxStorage, inputBoxStorage, boxSize, /*isVolatile=*/false); + loc, newBoxStorage, inputBoxStorage, boxSize, isVolatile); if (std::optional optionalTag = load.getTbaa()) memcpy.setTBAATags(*optionalTag); @@ -3255,8 +3256,9 @@ struct LoadOpConversion : public fir::FIROpConversion { attachTBAATag(memcpy, boxTy, boxTy, nullptr); rewriter.replaceOp(load, newBoxStorage); } else { + auto memref = adaptor.getOperands()[0]; auto loadOp = rewriter.create( - load.getLoc(), llvmLoadTy, adaptor.getOperands(), load->getAttrs()); + load.getLoc(), llvmLoadTy, memref, /*alignment=*/0, isVolatile); if (std::optional optionalTag = load.getTbaa()) loadOp.setTBAATags(*optionalTag); else diff --git a/flang/lib/Optimizer/Dialect/FIRType.cpp b/flang/lib/Optimizer/Dialect/FIRType.cpp index dc0bee9b060c9..e2dc1ed3f3ecb 100644 --- a/flang/lib/Optimizer/Dialect/FIRType.cpp +++ b/flang/lib/Optimizer/Dialect/FIRType.cpp @@ -1057,18 +1057,39 @@ unsigned fir::RecordType::getFieldIndex(llvm::StringRef ident) { // ReferenceType //===----------------------------------------------------------------------===// -// `ref` `<` type `>` +// `ref` `<` type (`, volatile` $volatile^)? (`, async` $async^)? `>` mlir::Type fir::ReferenceType::parse(mlir::AsmParser &parser) { - return parseTypeSingleton(parser); + if (parser.parseLess()) + return {}; + + mlir::Type eleTy; + if (parser.parseType(eleTy)) + return {}; + + bool isVolatile = false; + if (parser.parseOptionalComma()) { + if (parser.parseOptionalKeyword(getVolatileKeyword())) { + isVolatile = true; + } else { + return {}; + } + } + + if (parser.parseGreater()) + return {}; + return ReferenceType::get(eleTy, isVolatile); } void fir::ReferenceType::print(mlir::AsmPrinter &printer) const { - printer << "<" << getEleTy() << '>'; + printer << "<" << getEleTy(); + if (isVolatile()) + printer << ", volatile"; + printer << '>'; } llvm::LogicalResult fir::ReferenceType::verify( llvm::function_ref emitError, - mlir::Type eleTy) { + mlir::Type eleTy, bool isVolatile) { if (mlir::isa(eleTy)) return emitError() << "cannot build a reference to type: " << eleTy << '\n'; diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp index 8851a3a7187b9..4a3308ff4e747 100644 --- a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp +++ b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp @@ -214,6 +214,13 @@ void hlfir::DeclareOp::build(mlir::OpBuilder &builder, auto nameAttr = builder.getStringAttr(uniq_name); mlir::Type inputType = memref.getType(); bool hasExplicitLbs = hasExplicitLowerBounds(shape); + if (fortran_attrs && mlir::isa(inputType) && + bitEnumContainsAny(fortran_attrs.getFlags(), + fir::FortranVariableFlagsEnum::fortran_volatile)) { + auto refType = mlir::cast(inputType); + inputType = fir::ReferenceType::get(refType.getEleTy(), true); + memref = builder.create(memref.getLoc(), inputType, memref); + } mlir::Type hlfirVariableType = getHLFIRVariableType(inputType, hasExplicitLbs); build(builder, result, {hlfirVariableType, inputType}, memref, shape, diff --git a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp index 96a3622f4afee..e22b3d224ca1f 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp @@ -1126,7 +1126,7 @@ class ReductionMaskConversion : public mlir::OpRewritePattern { builder.create(loc, flagSet, flagRef); mlir::Type resultElemTy = hlfir::getFortranElementType(resultArr.getType()); - mlir::Type returnRefTy = builder.getRefType(resultElemTy); + mlir::Type returnRefTy = builder.getRefType(resultElemTy, fir::isa_volatile_ref_type(flagRef.getType())); mlir::IndexType idxTy = builder.getIndexType(); for (unsigned int i = 0; i < rank; ++i) { @@ -1153,7 +1153,7 @@ class ReductionMaskConversion : public mlir::OpRewritePattern { auto getAddrFn = [](fir::FirOpBuilder builder, mlir::Location loc, const mlir::Type &resultElemType, mlir::Value resultArr, mlir::Value index) { - mlir::Type resultRefTy = builder.getRefType(resultElemType); + mlir::Type resultRefTy = builder.getRefType(resultElemType, fir::isa_volatile_ref_type(resultArr.getType())); mlir::Value oneIdx = builder.createIntegerConstant(loc, builder.getIndexType(), 1); index = builder.create(loc, index, oneIdx); @@ -1162,8 +1162,9 @@ class ReductionMaskConversion : public mlir::OpRewritePattern { }; // Initialize the result + const bool isVolatile = fir::isa_volatile_ref_type(resultArr.getType()); mlir::Type resultElemTy = hlfir::getFortranElementType(resultArr.getType()); - mlir::Type resultRefTy = builder.getRefType(resultElemTy); + mlir::Type resultRefTy = builder.getRefType(resultElemTy, isVolatile); mlir::Value returnValue = builder.createIntegerConstant(loc, resultElemTy, 0); for (unsigned int i = 0; i < rank; ++i) {