diff --git a/flang/include/flang/Optimizer/Builder/HLFIRTools.h b/flang/include/flang/Optimizer/Builder/HLFIRTools.h index ac80873dc374f..6390a3a926fba 100644 --- a/flang/include/flang/Optimizer/Builder/HLFIRTools.h +++ b/flang/include/flang/Optimizer/Builder/HLFIRTools.h @@ -62,6 +62,15 @@ class Entity : public mlir::Value { bool isProcedurePointer() const { return hlfir::isFortranProcedurePointerType(getType()); } + bool isVolatile() const { + if (auto iface = getIfVariableInterface()) { + if (auto attrs = iface.getFortranAttrs()) { + return bitEnumContainsAny( + attrs.value(), fir::FortranVariableFlagsEnum::fortran_volatile); + } + } + return false; + } bool isBoxAddressOrValue() const { return hlfir::isBoxAddressOrValueType(getType()); } diff --git a/flang/include/flang/Optimizer/CodeGen/CGOps.td b/flang/include/flang/Optimizer/CodeGen/CGOps.td index f65291fc64c17..a0b92ae97df5c 100644 --- a/flang/include/flang/Optimizer/CodeGen/CGOps.td +++ b/flang/include/flang/Optimizer/CodeGen/CGOps.td @@ -199,7 +199,7 @@ def fircg_XArrayCoorOp : fircg_Op<"ext_array_coor", [AttrSizedOperandSegments]> Variadic:$indices, Variadic:$lenParams ); - let results = (outs fir_ReferenceType); + let results = (outs AnyReferenceType); let assemblyFormat = [{ $memref (`(`$shape^`)`)? (`origin` $shift^)? (`[`$slice^`]`)? diff --git a/flang/include/flang/Optimizer/Dialect/FIRAttr.td b/flang/include/flang/Optimizer/Dialect/FIRAttr.td index 8e86d82f38df4..6b0f284f53d35 100644 --- a/flang/include/flang/Optimizer/Dialect/FIRAttr.td +++ b/flang/include/flang/Optimizer/Dialect/FIRAttr.td @@ -18,7 +18,7 @@ include "mlir/IR/EnumAttr.td" class fir_Attr : AttrDef; -def FIRnoAttributes : I32BitEnumAttrCaseNone<"None">; +def FIRnoAttributes : I32BitEnumAttrCaseNone<"None">; def FIRallocatable : I32BitEnumAttrCaseBit<"allocatable", 0>; def FIRasynchronous : I32BitEnumAttrCaseBit<"asynchronous", 1>; def FIRbind_c : I32BitEnumAttrCaseBit<"bind_c", 2>; diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td index 7147a2401baa7..eff3a9b756af0 100644 --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -1766,7 +1766,7 @@ def fir_ArrayCoorOp : fir_Op<"array_coor", Variadic:$typeparams ); - let results = (outs fir_ReferenceType); + let results = (outs AnyReferenceType); let assemblyFormat = [{ $memref (`(`$shape^`)`)? (`[`$slice^`]`)? $indices (`typeparams` diff --git a/flang/include/flang/Optimizer/Dialect/FIRTypes.td b/flang/include/flang/Optimizer/Dialect/FIRTypes.td index fd5bbbe44751f..cd40b6579cf4c 100644 --- a/flang/include/flang/Optimizer/Dialect/FIRTypes.td +++ b/flang/include/flang/Optimizer/Dialect/FIRTypes.td @@ -375,12 +375,36 @@ def fir_ReferenceType : FIR_Type<"Reference", "ref"> { let extraClassDeclaration = [{ mlir::Type getElementType() const { return getEleTy(); } + static mlir::Type get(mlir::Type t, bool isVolatile); }]; let genVerifyDecl = 1; let hasCustomAssemblyFormat = 1; } +def fir_VolatileReferenceType : FIR_Type<"VolatileReference", "volatile_ref"> { + let summary = "Volatile reference to an entity type"; + + let description = [{ + The type of a volatile reference to an entity in memory. + }]; + + let parameters = (ins "mlir::Type":$eleTy); + + let builders = [TypeBuilderWithInferredContext< + (ins "mlir::Type":$elementType), [{ + return Base::get(elementType.getContext(), elementType); + }]>, + ]; + + let extraClassDeclaration = [{ + mlir::Type getElementType() const { return getEleTy(); } + }]; + + let genVerifyDecl = 1; + let assemblyFormat = "`<` $eleTy `>`"; +} + def fir_ShapeType : FIR_Type<"Shape", "shape"> { let summary = "shape of a multidimensional array object"; @@ -598,18 +622,28 @@ def AnyCompositeLike : TypeConstraint, "any composite">; +def AnyReferenceType : TypeConstraint, + "any reference type">; + // Reference types -def AnyReferenceLike : TypeConstraint, "any reference">; +def AnyReferenceLike + : TypeConstraint< + Or<[fir_ReferenceType.predicate, fir_VolatileReferenceType.predicate, + fir_HeapType.predicate, fir_PointerType.predicate, + fir_LLVMPointerType.predicate]>, + "any reference">; def FuncType : TypeConstraint; def AnyCodeOrDataRefLike : TypeConstraint, "any code or data reference">; -def RefOrLLVMPtr : TypeConstraint, "fir.ref or fir.llvm_ptr">; +def RefOrLLVMPtr + : TypeConstraint< + Or<[fir_ReferenceType.predicate, fir_VolatileReferenceType.predicate, + fir_LLVMPointerType.predicate]>, + "fir.ref or fir.llvm_ptr">; def AnyBoxLike : TypeConstraint, "any reference or box like">; -def AnyRefOrBox : TypeConstraint, "any reference or box">; +def AnyRefOrBox + : TypeConstraint< + Or<[fir_ReferenceType.predicate, fir_VolatileReferenceType.predicate, + fir_HeapType.predicate, fir_PointerType.predicate, + IsBaseBoxTypePred]>, + "any reference or box">; def AnyRefOrBoxType : Type; def AnyShapeLike : TypeConstraint; // The legal types of global symbols -def AnyAddressableLike : TypeConstraint, "any addressable">; +def AnyAddressableLike + : TypeConstraint< + Or<[fir_ReferenceType.predicate, fir_VolatileReferenceType.predicate, + FunctionType.predicate]>, + "any addressable">; def ArrayOrBoxOrRecord : TypeConstraint, diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td index f69930d5b53b3..cdd2b776fd186 100644 --- a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td +++ b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td @@ -124,7 +124,7 @@ def hlfir_DeclareOp : hlfir_Op<"declare", [AttrSizedOperandSegments, /// Given a FIR memory type, and information about non default lower /// bounds, get the related HLFIR variable type. - static mlir::Type getHLFIRVariableType(mlir::Type type, bool hasLowerBounds); + static mlir::Type getHLFIRVariableType(mlir::Type type, bool hasLowerBounds, bool isVolatile=false); }]; let hasVerifier = 1; diff --git a/flang/lib/Lower/CallInterface.cpp b/flang/lib/Lower/CallInterface.cpp index 226ba1e52c968..c741f1c1d2c76 100644 --- a/flang/lib/Lower/CallInterface.cpp +++ b/flang/lib/Lower/CallInterface.cpp @@ -1112,7 +1112,7 @@ class Fortran::lower::CallInterfaceImpl { if (obj.attrs.test(Attrs::Value)) isValueAttr = true; // TODO: do we want an mlir::Attribute as well? if (obj.attrs.test(Attrs::Volatile)) { - TODO(loc, "VOLATILE in procedure interface"); + // TODO(loc, "VOLATILE in procedure interface"); addMLIRAttr(fir::getVolatileAttrName()); } // obj.attrs.test(Attrs::Asynchronous) does not impact the way the argument diff --git a/flang/lib/Lower/ConvertExprToHLFIR.cpp b/flang/lib/Lower/ConvertExprToHLFIR.cpp index dc00e0b13f583..5fb8952fcd0ee 100644 --- a/flang/lib/Lower/ConvertExprToHLFIR.cpp +++ b/flang/lib/Lower/ConvertExprToHLFIR.cpp @@ -36,6 +36,11 @@ namespace { +/// Determine if a given symbol has the VOLATILE attribute. +static bool isVolatileSymbol(const Fortran::semantics::Symbol &symbol) { + return symbol.GetUltimate().attrs().test(Fortran::semantics::Attr::VOLATILE); +} + /// Lower Designators to HLFIR. class HlfirDesignatorBuilder { private: @@ -223,6 +228,18 @@ class HlfirDesignatorBuilder { designatorNode, getConverter().getFoldingContext(), /*namedConstantSectionsAreAlwaysContiguous=*/false)) return fir::BoxType::get(resultValueType); + + // Check if this should be a volatile reference + if constexpr (std::is_same_v, + Fortran::evaluate::SymbolRef>) { + if (isVolatileSymbol(designatorNode.get())) + return fir::VolatileReferenceType::get(resultValueType); + } else if constexpr (std::is_same_v, + Fortran::evaluate::Component>) { + if (isVolatileSymbol(designatorNode.GetLastSymbol())) + return fir::VolatileReferenceType::get(resultValueType); + } + // Other designators can be handled as raw addresses. return fir::ReferenceType::get(resultValueType); } diff --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp index 1a31ca33e9465..0d076ac59c8bc 100644 --- a/flang/lib/Optimizer/Builder/HLFIRTools.cpp +++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp @@ -756,11 +756,13 @@ std::pair hlfir::genVariableFirBaseShapeAndParams( auto params = fir::getTypeParams(exv); typeParams.append(params.begin(), params.end()); } - if (entity.isScalar()) + if (entity.isScalar()) { return {fir::getBase(exv), mlir::Value{}}; - if (auto variableInterface = entity.getIfVariableInterface()) + } + if (auto variableInterface = entity.getIfVariableInterface()) { return {fir::getBase(exv), asEmboxShape(loc, builder, exv, variableInterface.getShape())}; + } return {fir::getBase(exv), builder.createShape(loc, exv)}; } @@ -809,6 +811,9 @@ mlir::Type hlfir::getVariableElementType(hlfir::Entity variable) { } else if (fir::isRecordWithTypeParameters(eleTy)) { return fir::BoxType::get(eleTy); } + if (variable.isVolatile()) { + return fir::VolatileReferenceType::get(eleTy); + } return fir::ReferenceType::get(eleTy); } diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp index 2cb4cea58c2b0..f7ed8ce7c6d54 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -704,7 +704,6 @@ struct ConvertOpConversion : public fir::FIROpConversion { auto fromTy = convertType(fromFirTy); auto toTy = convertType(toFirTy); mlir::Value op0 = adaptor.getOperands()[0]; - if (fromFirTy == toFirTy) { rewriter.replaceOp(convert, op0); return mlir::success(); @@ -3217,6 +3216,9 @@ struct LoadOpConversion : public fir::FIROpConversion { matchAndRewrite(fir::LoadOp load, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { + mlir::Type originalLoadTy = load.getMemref().getType(); + const bool isVolatile = + mlir::isa(originalLoadTy); mlir::Type llvmLoadTy = convertObjectType(load.getType()); if (auto boxTy = mlir::dyn_cast(load.getType())) { // fir.box is a special case because it is considered an ssa value in @@ -3256,7 +3258,7 @@ struct LoadOpConversion : public fir::FIROpConversion { rewriter.replaceOp(load, newBoxStorage); } else { auto loadOp = rewriter.create( - load.getLoc(), llvmLoadTy, adaptor.getOperands(), load->getAttrs()); + load.getLoc(), llvmLoadTy, adaptor.getOperands()[0], 0, isVolatile); if (std::optional optionalTag = load.getTbaa()) loadOp.setTBAATags(*optionalTag); else @@ -3531,6 +3533,9 @@ struct StoreOpConversion : public fir::FIROpConversion { mlir::ConversionPatternRewriter &rewriter) const override { mlir::Location loc = store.getLoc(); mlir::Type storeTy = store.getValue().getType(); + mlir::Type originalStoreTy = store.getMemref().getType(); + const bool isVolatile = + mlir::isa(originalStoreTy); mlir::Value llvmValue = adaptor.getValue(); mlir::Value llvmMemref = adaptor.getMemref(); mlir::LLVM::AliasAnalysisOpInterface newOp; @@ -3541,10 +3546,11 @@ struct StoreOpConversion : public fir::FIROpConversion { TypePair boxTypePair{boxTy, llvmBoxTy}; mlir::Value boxSize = computeBoxSize(loc, boxTypePair, llvmValue, rewriter); - newOp = rewriter.create( - loc, llvmMemref, llvmValue, boxSize, /*isVolatile=*/false); + newOp = rewriter.create(loc, llvmMemref, llvmValue, + boxSize, isVolatile); } else { - newOp = rewriter.create(loc, llvmValue, llvmMemref); + newOp = rewriter.create(loc, llvmValue, llvmMemref, + 0, isVolatile, false); } if (std::optional optionalTag = store.getTbaa()) newOp.setTBAATags(*optionalTag); diff --git a/flang/lib/Optimizer/CodeGen/TypeConverter.cpp b/flang/lib/Optimizer/CodeGen/TypeConverter.cpp index 1a1d3a8cfb870..ca071c2fc08ac 100644 --- a/flang/lib/Optimizer/CodeGen/TypeConverter.cpp +++ b/flang/lib/Optimizer/CodeGen/TypeConverter.cpp @@ -103,6 +103,8 @@ LLVMTypeConverter::LLVMTypeConverter(mlir::ModuleOp module, bool applyTBAA, }); addConversion( [&](fir::ReferenceType ref) { return convertPointerLike(ref); }); + addConversion( + [&](fir::VolatileReferenceType ref) { return convertPointerLike(ref); }); addConversion([&](fir::SequenceType sequence) { return convertSequenceType(sequence); }); diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp index 033d6453a619a..2cc7d9359f94d 100644 --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -270,7 +270,7 @@ llvm::LogicalResult fir::AllocaOp::verify() { if (verifyTypeParamCount(getInType(), numLenParams())) return emitOpError("LEN params do not correspond to type"); mlir::Type outType = getType(); - if (!mlir::isa(outType)) + if (!mlir::isa(outType)) return emitOpError("must be a !fir.ref type"); return mlir::success(); } @@ -305,8 +305,8 @@ static mlir::Type wrapAllocMemResultType(mlir::Type intype) { // Fortran semantics: C852 an entity cannot be both ALLOCATABLE and POINTER // 8.5.3 note 1 prohibits ALLOCATABLE procedures as well // FIR semantics: one may not allocate a memory reference value - if (mlir::isa(intype)) + if (mlir::isa(intype)) return {}; return fir::HeapType::get(intype); } @@ -441,8 +441,9 @@ llvm::LogicalResult fir::ArrayCoorOp::verify() { if (sliceTy.getRank() != arrDim) return emitOpError("rank of dimension in slice mismatched"); } - if (!validTypeParams(getMemref().getType(), getTypeparams())) + if (!validTypeParams(getMemref().getType(), getTypeparams())) { return emitOpError("invalid type parameters"); + } return mlir::success(); } @@ -823,8 +824,8 @@ void fir::ArrayCoorOp::getCanonicalizationPatterns( //===----------------------------------------------------------------------===// static mlir::Type adjustedElementType(mlir::Type t) { - if (auto ty = mlir::dyn_cast(t)) { - auto eleTy = ty.getEleTy(); + if (fir::isa_ref_type(t)) { + mlir::Type eleTy = fir::dyn_cast_ptrEleTy(t); if (fir::isa_char(eleTy)) return eleTy; if (fir::isa_derived(eleTy)) @@ -1364,9 +1365,10 @@ bool fir::ConvertOp::isFloatCompatible(mlir::Type ty) { } bool fir::ConvertOp::isPointerCompatible(mlir::Type ty) { - return mlir::isa(ty); + return mlir::isa(ty); } static std::optional getVectorElementType(mlir::Type ty) { diff --git a/flang/lib/Optimizer/Dialect/FIRType.cpp b/flang/lib/Optimizer/Dialect/FIRType.cpp index dc0bee9b060c9..90614f1305c27 100644 --- a/flang/lib/Optimizer/Dialect/FIRType.cpp +++ b/flang/lib/Optimizer/Dialect/FIRType.cpp @@ -66,8 +66,8 @@ static bool isaIntegerType(mlir::Type ty) { bool verifyRecordMemberType(mlir::Type ty) { return !mlir::isa( - ty); + SliceType, FieldType, LenType, ReferenceType, TypeDescType, + VolatileReferenceType>(ty); } bool verifySameLists(llvm::ArrayRef a1, @@ -217,15 +217,17 @@ mlir::Type getDerivedType(mlir::Type ty) { mlir::Type dyn_cast_ptrEleTy(mlir::Type t) { return llvm::TypeSwitch(t) - .Case([](auto p) { return p.getEleTy(); }) + .Case( + [](auto p) { return p.getEleTy(); }) .Default([](mlir::Type) { return mlir::Type{}; }); } mlir::Type dyn_cast_ptrOrBoxEleTy(mlir::Type t) { return llvm::TypeSwitch(t) - .Case([](auto p) { return p.getEleTy(); }) + .Case( + [](auto p) { return p.getEleTy(); }) .Case( [](auto p) { return unwrapRefType(p.getEleTy()); }) .Default([](mlir::Type) { return mlir::Type{}; }); @@ -596,6 +598,10 @@ std::string getTypeAsString(mlir::Type ty, const fir::KindMapping &kindMap, } else if (auto refTy = mlir::dyn_cast_or_null(ty)) { name << "ref_"; ty = refTy.getEleTy(); + } else if (auto refTy = + mlir::dyn_cast_or_null(ty)) { + name << "volatile_ref_"; + ty = refTy.getEleTy(); } else if (auto ptrTy = mlir::dyn_cast_or_null(ty)) { name << "ptr_"; ty = ptrTy.getEleTy(); @@ -650,11 +656,12 @@ mlir::Type changeElementType(mlir::Type type, mlir::Type newElementType, return fir::SequenceType::get(seqTy.getShape(), newElementType); }) .Case([&](auto t) -> mlir::Type { - using FIRT = decltype(t); - return FIRT::get( - changeElementType(t.getEleTy(), newElementType, turnBoxIntoClass)); - }) + fir::VolatileReferenceType, fir::ClassType>( + [&](auto t) -> mlir::Type { + using FIRT = decltype(t); + return FIRT::get(changeElementType(t.getEleTy(), newElementType, + turnBoxIntoClass)); + }) .Case([&](fir::BoxType t) -> mlir::Type { mlir::Type newInnerType = changeElementType(t.getEleTy(), newElementType, false); @@ -725,13 +732,16 @@ BoxProcType::verify(llvm::function_ref emitError, if (auto refTy = mlir::dyn_cast(eleTy)) if (mlir::isa(refTy)) return mlir::success(); + if (auto refTy = mlir::dyn_cast(eleTy)) + if (mlir::isa(refTy)) + return mlir::success(); return emitError() << "invalid type for boxproc" << eleTy << '\n'; } static bool cannotBePointerOrHeapElementType(mlir::Type eleTy) { return mlir::isa(eleTy); + ReferenceType, VolatileReferenceType, TypeDescType>(eleTy); } //===----------------------------------------------------------------------===// @@ -1066,11 +1076,15 @@ void fir::ReferenceType::print(mlir::AsmPrinter &printer) const { printer << "<" << getEleTy() << '>'; } +mlir::Type fir::ReferenceType::get(mlir::Type t, bool isVolatile) { + return isVolatile ? (mlir::Type)fir::VolatileReferenceType::get(t) : (mlir::Type)fir::ReferenceType::get(t); +} + llvm::LogicalResult fir::ReferenceType::verify( llvm::function_ref emitError, mlir::Type eleTy) { if (mlir::isa(eleTy)) + ReferenceType, VolatileReferenceType, TypeDescType>(eleTy)) return emitError() << "cannot build a reference to type: " << eleTy << '\n'; return mlir::success(); } @@ -1147,7 +1161,8 @@ llvm::LogicalResult fir::SequenceType::verify( // DIMENSION attribute can only be applied to an intrinsic or record type if (mlir::isa(eleTy)) + ReferenceType, VolatileReferenceType, TypeDescType, + SequenceType>(eleTy)) return emitError() << "cannot build an array of this element type: " << eleTy << '\n'; return mlir::success(); @@ -1220,7 +1235,7 @@ llvm::LogicalResult fir::TypeDescType::verify( mlir::Type eleTy) { if (mlir::isa(eleTy)) + VolatileReferenceType, TypeDescType>(eleTy)) return emitError() << "cannot build a type descriptor of type: " << eleTy << '\n'; return mlir::success(); @@ -1319,11 +1334,12 @@ changeTypeShape(mlir::Type type, return fir::SequenceType::get(*newShape, seqTy.getEleTy()); return seqTy.getEleTy(); }) - .Case([&](auto t) -> mlir::Type { - using FIRT = decltype(t); - return FIRT::get(changeTypeShape(t.getEleTy(), newShape)); - }) + .Case( + [&](auto t) -> mlir::Type { + using FIRT = decltype(t); + return FIRT::get(changeTypeShape(t.getEleTy(), newShape)); + }) .Default([&](mlir::Type t) -> mlir::Type { assert((fir::isa_trivial(t) || llvm::isa(t) || llvm::isa(t)) && @@ -1393,10 +1409,13 @@ void FIROpsDialect::registerTypes() { addTypes(); + VolatileReferenceType, SequenceType, ShapeType, ShapeShiftType, + ShiftType, SliceType, TypeDescType, fir::VectorType, + fir::DummyScopeType>(); fir::ReferenceType::attachInterface< OpenMPPointerLikeModel>(*getContext()); + fir::VolatileReferenceType::attachInterface< + OpenMPPointerLikeModel>(*getContext()); fir::PointerType::attachInterface>( *getContext()); fir::HeapType::attachInterface>( @@ -1467,3 +1486,8 @@ fir::getTypeSizeAndAlignmentOrCrash(mlir::Location loc, mlir::Type ty, return *result; TODO(loc, "computing size of a component"); } + +llvm::LogicalResult fir::VolatileReferenceType::verify( + llvm::function_ref, mlir::Type) { + return mlir::success(); +} diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp index cb77aef74acd5..827564aa60471 100644 --- a/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp +++ b/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp @@ -82,7 +82,8 @@ void hlfir::ExprType::print(mlir::AsmPrinter &printer) const { bool hlfir::isFortranVariableType(mlir::Type type) { return llvm::TypeSwitch(type) - .Case([](auto p) { + .Case([](auto p) { mlir::Type eleType = p.getEleTy(); return mlir::isa(eleType) || !fir::hasDynamicSize(eleType); diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp index 8851a3a7187b9..f7abeb202a6a8 100644 --- a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp +++ b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp @@ -180,7 +180,8 @@ void hlfir::AssignOp::getEffects( /// Given a FIR memory type, and information about non default lower bounds, get /// the related HLFIR variable type. mlir::Type hlfir::DeclareOp::getHLFIRVariableType(mlir::Type inputType, - bool hasExplicitLowerBounds) { + bool hasExplicitLowerBounds, + bool isVolatile) { mlir::Type type = fir::unwrapRefType(inputType); if (mlir::isa(type)) return inputType; @@ -196,6 +197,14 @@ mlir::Type hlfir::DeclareOp::getHLFIRVariableType(mlir::Type inputType, fir::isRecordWithTypeParameters(eleType); if (hasExplicitLowerBounds || hasDynamicExtents || hasDynamicLengthParams) return fir::BoxType::get(type); + + // If this is a reference type and has the volatile attribute, use + // VolatileReferenceType + if (isVolatile && mlir::isa(inputType)) { + auto refType = mlir::cast(inputType); + return fir::VolatileReferenceType::get(refType.getEleTy()); + } + return inputType; } @@ -214,18 +223,49 @@ void hlfir::DeclareOp::build(mlir::OpBuilder &builder, auto nameAttr = builder.getStringAttr(uniq_name); mlir::Type inputType = memref.getType(); bool hasExplicitLbs = hasExplicitLowerBounds(shape); + bool isVolatile = false; + if (fortran_attrs && mlir::isa(inputType) && + bitEnumContainsAny(fortran_attrs.getFlags(), + fir::FortranVariableFlagsEnum::fortran_volatile)) { + auto refType = mlir::cast(inputType); + isVolatile = true; + inputType = fir::VolatileReferenceType::get(refType.getEleTy()); + memref = builder.create(memref.getLoc(), inputType, memref); + } mlir::Type hlfirVariableType = - getHLFIRVariableType(inputType, hasExplicitLbs); + getHLFIRVariableType(inputType, hasExplicitLbs, isVolatile); + build(builder, result, {hlfirVariableType, inputType}, memref, shape, typeparams, dummy_scope, nameAttr, fortran_attrs, data_attr); } +static bool hlfirVariableTypeCompatible(mlir::Type memrefType, + mlir::Type outputType) { + // if the input and output types don't match, they are still compatible ONLY + // if this is due to the variable being declared volatile. + if (auto inputRefTy = mlir::dyn_cast(memrefType)) { + if (auto hlfirRefTy = + mlir::dyn_cast(outputType)) { + return hlfirRefTy.getEleTy() == inputRefTy.getEleTy(); + } + } + return memrefType == outputType; +} + llvm::LogicalResult hlfir::DeclareOp::verify() { if (getMemref().getType() != getResult(1).getType()) return emitOpError("second result type must match input memref type"); + fir::FortranVariableFlagsAttr attrs; + bool isVolatile = false; + if (getFortranAttrs().has_value()) { + auto flagsEnum = getFortranAttrs().value(); + isVolatile = bitEnumContainsAny( + flagsEnum, fir::FortranVariableFlagsEnum::fortran_volatile); + attrs = fir::FortranVariableFlagsAttr::get(getContext(), flagsEnum); + } mlir::Type hlfirVariableType = getHLFIRVariableType( - getMemref().getType(), hasExplicitLowerBounds(getShape())); - if (hlfirVariableType != getResult(0).getType()) + getMemref().getType(), hasExplicitLowerBounds(getShape()), isVolatile); + if (!hlfirVariableTypeCompatible(getResult(0).getType(), hlfirVariableType)) return emitOpError("first result type is inconsistent with variable " "properties: expected ") << hlfirVariableType; diff --git a/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp index 496a5560ac615..98b64634f7d2f 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp @@ -387,7 +387,9 @@ class DeclareOpConversion : public mlir::OpRewritePattern { } else { if (hlfirBaseType != firBase.getType()) { declareOp.emitOpError() - << "unhandled HLFIR variable type '" << hlfirBaseType << "'\n"; + << "unhandled HLFIR variable type '" << hlfirBaseType + << "' does not match fir type '" << firBase.getType() + << "' with memref '" << memref << "'\n"; return mlir::failure(); } hlfirBase = firBase; @@ -418,10 +420,13 @@ class DesignateOpConversion firstElementIndices.push_back(indices[i]); i = i + (isTriplet ? 3 : 1); } - mlir::Type arrayCoorType = fir::ReferenceType::get(baseEleTy); + auto designateResultType = designate.getResult().getType(); + auto isVolatile = + mlir::isa(designateResultType); + mlir::Type refTy = fir::ReferenceType::get(baseEleTy, isVolatile); base = builder.create( - loc, arrayCoorType, base, shape, - /*slice=*/mlir::Value{}, firstElementIndices, firBaseTypeParameters); + loc, refTy, base, shape, /*slice=*/mlir::Value{}, + firstElementIndices, firBaseTypeParameters); return base; } @@ -436,7 +441,6 @@ class DesignateOpConversion fir::FirOpBuilder builder(rewriter, designate.getOperation()); hlfir::Entity baseEntity(designate.getMemref()); - if (baseEntity.isMutableBox()) TODO(loc, "hlfir::designate load of pointer or allocatable"); @@ -581,8 +585,9 @@ class DesignateOpConversion // shape. The base may be an array, or a scalar. mlir::Type resultAddressType = designateResultType; if (auto boxCharType = - mlir::dyn_cast(designateResultType)) + mlir::dyn_cast(designateResultType)) { resultAddressType = fir::ReferenceType::get(boxCharType.getEleTy()); + } // Array element indexing. if (!designate.getIndices().empty()) {