From 080c9efda7dd117b7b0d59a939f838f4f986976b Mon Sep 17 00:00:00 2001 From: Asher Mancinelli Date: Tue, 18 Mar 2025 10:03:24 -0700 Subject: [PATCH 1/5] init prototype --- .../flang/Optimizer/Builder/HLFIRTools.h | 9 +++ .../include/flang/Optimizer/CodeGen/CGOps.td | 10 ++- .../flang/Optimizer/Dialect/FIRAttr.td | 2 +- .../include/flang/Optimizer/Dialect/FIROps.td | 2 +- .../flang/Optimizer/Dialect/FIRTypes.td | 34 ++++++++- .../include/flang/Optimizer/HLFIR/HLFIROps.td | 1 + flang/lib/Lower/CallInterface.cpp | 2 +- flang/lib/Lower/ConvertExprToHLFIR.cpp | 20 +++++ flang/lib/Optimizer/Builder/HLFIRTools.cpp | 13 +++- flang/lib/Optimizer/CodeGen/CodeGen.cpp | 22 +++++- flang/lib/Optimizer/CodeGen/TypeConverter.cpp | 2 + flang/lib/Optimizer/Dialect/FIROps.cpp | 24 +++--- flang/lib/Optimizer/Dialect/FIRType.cpp | 35 ++++++--- flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp | 2 +- flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp | 62 +++++++++++++++- .../HLFIR/Transforms/ConvertToFIR.cpp | 74 +++++++++++++++++-- 16 files changed, 267 insertions(+), 47 deletions(-) diff --git a/flang/include/flang/Optimizer/Builder/HLFIRTools.h b/flang/include/flang/Optimizer/Builder/HLFIRTools.h index ac80873dc374f..0128e8f558f0b 100644 --- a/flang/include/flang/Optimizer/Builder/HLFIRTools.h +++ b/flang/include/flang/Optimizer/Builder/HLFIRTools.h @@ -51,6 +51,7 @@ inline bool isFortranEntityWithAttributes(mlir::Value value) { class Entity : public mlir::Value { public: explicit Entity(mlir::Value value) : mlir::Value(value) { + llvm::dbgs() << value << " " << value.getType() << "\n"; assert(isFortranEntity(value) && "must be a value representing a Fortran value or variable like"); } @@ -62,6 +63,14 @@ class Entity : public mlir::Value { bool isProcedurePointer() const { return hlfir::isFortranProcedurePointerType(getType()); } + bool isVolatile() const { + if (auto iface = getIfVariableInterface()) { + if (auto attrs = iface.getFortranAttrs()) { + return bitEnumContainsAny(attrs.value(), fir::FortranVariableFlagsEnum::fortran_volatile); + } + } + return false; + } bool isBoxAddressOrValue() const { return hlfir::isBoxAddressOrValueType(getType()); } diff --git a/flang/include/flang/Optimizer/CodeGen/CGOps.td b/flang/include/flang/Optimizer/CodeGen/CGOps.td index f65291fc64c17..adb20f8b7264b 100644 --- a/flang/include/flang/Optimizer/CodeGen/CGOps.td +++ b/flang/include/flang/Optimizer/CodeGen/CGOps.td @@ -30,6 +30,14 @@ def fircg_Dialect : Dialect { class fircg_Op traits> : Op; +def fircg_VolatileLoadOp : fircg_Op<"volatile_load", []> { + let summary = "for internal conversion only"; + let description = [{ + arg must be ref of volatile + }]; + let arguments = (ins Arg:$memref); +} + // Extended embox operation. def fircg_XEmboxOp : fircg_Op<"ext_embox", [AttrSizedOperandSegments]> { let summary = "for internal conversion only"; @@ -199,7 +207,7 @@ def fircg_XArrayCoorOp : fircg_Op<"ext_array_coor", [AttrSizedOperandSegments]> Variadic:$indices, Variadic:$lenParams ); - let results = (outs fir_ReferenceType); + let results = (outs AnyReferenceType); let assemblyFormat = [{ $memref (`(`$shape^`)`)? (`origin` $shift^)? (`[`$slice^`]`)? diff --git a/flang/include/flang/Optimizer/Dialect/FIRAttr.td b/flang/include/flang/Optimizer/Dialect/FIRAttr.td index 8e86d82f38df4..6b0f284f53d35 100644 --- a/flang/include/flang/Optimizer/Dialect/FIRAttr.td +++ b/flang/include/flang/Optimizer/Dialect/FIRAttr.td @@ -18,7 +18,7 @@ include "mlir/IR/EnumAttr.td" class fir_Attr : AttrDef; -def FIRnoAttributes : I32BitEnumAttrCaseNone<"None">; +def FIRnoAttributes : I32BitEnumAttrCaseNone<"None">; def FIRallocatable : I32BitEnumAttrCaseBit<"allocatable", 0>; def FIRasynchronous : I32BitEnumAttrCaseBit<"asynchronous", 1>; def FIRbind_c : I32BitEnumAttrCaseBit<"bind_c", 2>; diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td index 7147a2401baa7..eff3a9b756af0 100644 --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -1766,7 +1766,7 @@ def fir_ArrayCoorOp : fir_Op<"array_coor", Variadic:$typeparams ); - let results = (outs fir_ReferenceType); + let results = (outs AnyReferenceType); let assemblyFormat = [{ $memref (`(`$shape^`)`)? (`[`$slice^`]`)? $indices (`typeparams` diff --git a/flang/include/flang/Optimizer/Dialect/FIRTypes.td b/flang/include/flang/Optimizer/Dialect/FIRTypes.td index fd5bbbe44751f..2eb0e5ba1af80 100644 --- a/flang/include/flang/Optimizer/Dialect/FIRTypes.td +++ b/flang/include/flang/Optimizer/Dialect/FIRTypes.td @@ -381,6 +381,29 @@ def fir_ReferenceType : FIR_Type<"Reference", "ref"> { let hasCustomAssemblyFormat = 1; } +def fir_VolatileReferenceType : FIR_Type<"VolatileReference", "volatile_ref"> { + let summary = "Volatile reference to an entity type"; + + let description = [{ + The type of a volatile reference to an entity in memory. + }]; + + let parameters = (ins "mlir::Type":$eleTy); + + let builders = [ + TypeBuilderWithInferredContext<(ins "mlir::Type":$elementType), [{ + return Base::get(elementType.getContext(), elementType); + }]>, + ]; + + let extraClassDeclaration = [{ + mlir::Type getElementType() const { return getEleTy(); } + }]; + + let genVerifyDecl = 1; + let assemblyFormat = "`<` $eleTy `>`"; +} + def fir_ShapeType : FIR_Type<"Shape", "shape"> { let summary = "shape of a multidimensional array object"; @@ -598,9 +621,12 @@ def AnyCompositeLike : TypeConstraint, "any composite">; +def AnyReferenceType : TypeConstraint, "any reference type">; + // Reference types def AnyReferenceLike : TypeConstraint, "any reference">; def FuncType : TypeConstraint; @@ -609,7 +635,7 @@ def AnyCodeOrDataRefLike : TypeConstraint, "any code or data reference">; def RefOrLLVMPtr : TypeConstraint, "fir.ref or fir.llvm_ptr">; + fir_VolatileReferenceType.predicate, fir_LLVMPointerType.predicate]>, "fir.ref or fir.llvm_ptr">; def AnyBoxLike : TypeConstraint, "any reference or box like">; def AnyRefOrBox : TypeConstraint, "any reference or box">; def AnyRefOrBoxType : Type; @@ -652,7 +678,7 @@ def AnyCoordinateType : Type; // The legal types of global symbols def AnyAddressableLike : TypeConstraint, "any addressable">; + fir_VolatileReferenceType.predicate, FunctionType.predicate]>, "any addressable">; def ArrayOrBoxOrRecord : TypeConstraint, diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td index f69930d5b53b3..5935d0b356d38 100644 --- a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td +++ b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td @@ -125,6 +125,7 @@ def hlfir_DeclareOp : hlfir_Op<"declare", [AttrSizedOperandSegments, /// Given a FIR memory type, and information about non default lower /// bounds, get the related HLFIR variable type. static mlir::Type getHLFIRVariableType(mlir::Type type, bool hasLowerBounds); + static mlir::Type getHLFIRVariableType(mlir::Type type, bool hasLowerBounds, bool isVolatile); }]; let hasVerifier = 1; diff --git a/flang/lib/Lower/CallInterface.cpp b/flang/lib/Lower/CallInterface.cpp index 226ba1e52c968..c741f1c1d2c76 100644 --- a/flang/lib/Lower/CallInterface.cpp +++ b/flang/lib/Lower/CallInterface.cpp @@ -1112,7 +1112,7 @@ class Fortran::lower::CallInterfaceImpl { if (obj.attrs.test(Attrs::Value)) isValueAttr = true; // TODO: do we want an mlir::Attribute as well? if (obj.attrs.test(Attrs::Volatile)) { - TODO(loc, "VOLATILE in procedure interface"); + // TODO(loc, "VOLATILE in procedure interface"); addMLIRAttr(fir::getVolatileAttrName()); } // obj.attrs.test(Attrs::Asynchronous) does not impact the way the argument diff --git a/flang/lib/Lower/ConvertExprToHLFIR.cpp b/flang/lib/Lower/ConvertExprToHLFIR.cpp index dc00e0b13f583..23d938d164bef 100644 --- a/flang/lib/Lower/ConvertExprToHLFIR.cpp +++ b/flang/lib/Lower/ConvertExprToHLFIR.cpp @@ -36,6 +36,16 @@ namespace { +/// Determine if a given symbol or designator has the VOLATILE attribute. +static bool isVolatileDesignator(const Fortran::semantics::Symbol &symbol) { + return symbol.GetUltimate().attrs().test(Fortran::semantics::Attr::VOLATILE); +} + +/// Determine if a given symbol has the VOLATILE attribute. +static bool isVolatileSymbol(const Fortran::semantics::Symbol &symbol) { + return symbol.GetUltimate().attrs().test(Fortran::semantics::Attr::VOLATILE); +} + /// Lower Designators to HLFIR. class HlfirDesignatorBuilder { private: @@ -223,6 +233,16 @@ class HlfirDesignatorBuilder { designatorNode, getConverter().getFoldingContext(), /*namedConstantSectionsAreAlwaysContiguous=*/false)) return fir::BoxType::get(resultValueType); + + // Check if this should be a volatile reference + if constexpr (std::is_same_v, Fortran::evaluate::SymbolRef>) { + if (isVolatileSymbol(designatorNode.get())) + return fir::VolatileReferenceType::get(resultValueType); + } else if constexpr (std::is_same_v, Fortran::evaluate::Component>) { + if (isVolatileSymbol(designatorNode.GetLastSymbol())) + return fir::VolatileReferenceType::get(resultValueType); + } + // Other designators can be handled as raw addresses. return fir::ReferenceType::get(resultValueType); } diff --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp index 1a31ca33e9465..e1d1f94e49edc 100644 --- a/flang/lib/Optimizer/Builder/HLFIRTools.cpp +++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp @@ -751,16 +751,22 @@ std::pair hlfir::genVariableFirBaseShapeAndParams( mlir::Location loc, fir::FirOpBuilder &builder, Entity entity, llvm::SmallVectorImpl &typeParams) { auto [exv, cleanup] = translateToExtendedValue(loc, builder, entity); + llvm::dbgs() << "exv: " << exv << "\n"; assert(!cleanup && "variable to Exv should not produce cleanup"); if (entity.hasLengthParameters()) { + llvm::dbgs() << "entity.hasLengthParameters()\n"; auto params = fir::getTypeParams(exv); typeParams.append(params.begin(), params.end()); } - if (entity.isScalar()) + if (entity.isScalar()) { + llvm::dbgs() << "entity.isScalar()\n"; return {fir::getBase(exv), mlir::Value{}}; - if (auto variableInterface = entity.getIfVariableInterface()) + } + if (auto variableInterface = entity.getIfVariableInterface()) { + llvm::dbgs() << "variableInterface: " << variableInterface << "\n"; return {fir::getBase(exv), asEmboxShape(loc, builder, exv, variableInterface.getShape())}; + } return {fir::getBase(exv), builder.createShape(loc, exv)}; } @@ -809,6 +815,9 @@ mlir::Type hlfir::getVariableElementType(hlfir::Entity variable) { } else if (fir::isRecordWithTypeParameters(eleTy)) { return fir::BoxType::get(eleTy); } + if (variable.isVolatile()) { + return fir::VolatileReferenceType::get(eleTy); + } return fir::ReferenceType::get(eleTy); } diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp index 2cb4cea58c2b0..c3613832d2872 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -699,12 +699,17 @@ struct ConvertOpConversion : public fir::FIROpConversion { llvm::LogicalResult matchAndRewrite(fir::ConvertOp convert, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { + llvm::dbgs() << "ConvertOpConversion\n"; auto fromFirTy = convert.getValue().getType(); auto toFirTy = convert.getRes().getType(); + llvm::dbgs() << "fromFirTy: " << fromFirTy << "\n"; + llvm::dbgs() << "toFirTy: " << toFirTy << "\n"; auto fromTy = convertType(fromFirTy); auto toTy = convertType(toFirTy); + llvm::dbgs() << "fromTy: " << fromTy << "\n"; + llvm::dbgs() << "toTy: " << toTy << "\n"; mlir::Value op0 = adaptor.getOperands()[0]; - + llvm::dbgs() << "op0: " << op0 << "\n"; if (fromFirTy == toFirTy) { rewriter.replaceOp(convert, op0); return mlir::success(); @@ -3217,6 +3222,9 @@ struct LoadOpConversion : public fir::FIROpConversion { matchAndRewrite(fir::LoadOp load, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { + mlir::Type originalLoadTy = load.getMemref().getType(); + const bool isVolatile = mlir::isa(originalLoadTy); + auto volatileAttr = mlir::UnitAttr::get(load.getContext()); mlir::Type llvmLoadTy = convertObjectType(load.getType()); if (auto boxTy = mlir::dyn_cast(load.getType())) { // fir.box is a special case because it is considered an ssa value in @@ -3256,7 +3264,7 @@ struct LoadOpConversion : public fir::FIROpConversion { rewriter.replaceOp(load, newBoxStorage); } else { auto loadOp = rewriter.create( - load.getLoc(), llvmLoadTy, adaptor.getOperands(), load->getAttrs()); + load.getLoc(), llvmLoadTy, adaptor.getOperands()[0], 0, isVolatile); if (std::optional optionalTag = load.getTbaa()) loadOp.setTBAATags(*optionalTag); else @@ -3531,9 +3539,14 @@ struct StoreOpConversion : public fir::FIROpConversion { mlir::ConversionPatternRewriter &rewriter) const override { mlir::Location loc = store.getLoc(); mlir::Type storeTy = store.getValue().getType(); + mlir::Type originalStoreTy = store.getMemref().getType(); + const bool isVolatile = mlir::isa(originalStoreTy); mlir::Value llvmValue = adaptor.getValue(); mlir::Value llvmMemref = adaptor.getMemref(); mlir::LLVM::AliasAnalysisOpInterface newOp; + llvm::dbgs() << "storeTy: " << storeTy << "\n"; + llvm::dbgs() << "originalStoreTy: " << originalStoreTy << "\n"; + llvm::dbgs() << "isVolatile: " << isVolatile << "\n"; if (auto boxTy = mlir::dyn_cast(storeTy)) { mlir::Type llvmBoxTy = lowerTy().convertBoxTypeAsStruct(boxTy); // Always use memcpy because LLVM is not as effective at optimizing @@ -3542,10 +3555,11 @@ struct StoreOpConversion : public fir::FIROpConversion { mlir::Value boxSize = computeBoxSize(loc, boxTypePair, llvmValue, rewriter); newOp = rewriter.create( - loc, llvmMemref, llvmValue, boxSize, /*isVolatile=*/false); + loc, llvmMemref, llvmValue, boxSize, isVolatile); } else { - newOp = rewriter.create(loc, llvmValue, llvmMemref); + newOp = rewriter.create(loc, llvmValue, llvmMemref, 0, isVolatile, false); } + llvm::dbgs() << "newOp: " << newOp << "\n"; if (std::optional optionalTag = store.getTbaa()) newOp.setTBAATags(*optionalTag); else diff --git a/flang/lib/Optimizer/CodeGen/TypeConverter.cpp b/flang/lib/Optimizer/CodeGen/TypeConverter.cpp index 1a1d3a8cfb870..ca071c2fc08ac 100644 --- a/flang/lib/Optimizer/CodeGen/TypeConverter.cpp +++ b/flang/lib/Optimizer/CodeGen/TypeConverter.cpp @@ -103,6 +103,8 @@ LLVMTypeConverter::LLVMTypeConverter(mlir::ModuleOp module, bool applyTBAA, }); addConversion( [&](fir::ReferenceType ref) { return convertPointerLike(ref); }); + addConversion( + [&](fir::VolatileReferenceType ref) { return convertPointerLike(ref); }); addConversion([&](fir::SequenceType sequence) { return convertSequenceType(sequence); }); diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp index 033d6453a619a..8ec5c8fee69c7 100644 --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -270,7 +270,7 @@ llvm::LogicalResult fir::AllocaOp::verify() { if (verifyTypeParamCount(getInType(), numLenParams())) return emitOpError("LEN params do not correspond to type"); mlir::Type outType = getType(); - if (!mlir::isa(outType)) + if (!mlir::isa(outType)) return emitOpError("must be a !fir.ref type"); return mlir::success(); } @@ -305,7 +305,7 @@ static mlir::Type wrapAllocMemResultType(mlir::Type intype) { // Fortran semantics: C852 an entity cannot be both ALLOCATABLE and POINTER // 8.5.3 note 1 prohibits ALLOCATABLE procedures as well // FIR semantics: one may not allocate a memory reference value - if (mlir::isa(intype)) return {}; return fir::HeapType::get(intype); @@ -823,15 +823,18 @@ void fir::ArrayCoorOp::getCanonicalizationPatterns( //===----------------------------------------------------------------------===// static mlir::Type adjustedElementType(mlir::Type t) { + mlir::Type eleTy; if (auto ty = mlir::dyn_cast(t)) { - auto eleTy = ty.getEleTy(); - if (fir::isa_char(eleTy)) - return eleTy; - if (fir::isa_derived(eleTy)) - return eleTy; - if (mlir::isa(eleTy)) - return eleTy; + eleTy = ty.getEleTy(); + } else if (auto volType = mlir::dyn_cast(t)) { + eleTy = volType.getEleTy(); } + if (fir::isa_char(eleTy)) + return eleTy; + if (fir::isa_derived(eleTy)) + return eleTy; + if (mlir::isa(eleTy)) + return eleTy; return t; } @@ -1364,7 +1367,8 @@ bool fir::ConvertOp::isFloatCompatible(mlir::Type ty) { } bool fir::ConvertOp::isPointerCompatible(mlir::Type ty) { - return mlir::isa(ty); } diff --git a/flang/lib/Optimizer/Dialect/FIRType.cpp b/flang/lib/Optimizer/Dialect/FIRType.cpp index dc0bee9b060c9..ea663cc42e647 100644 --- a/flang/lib/Optimizer/Dialect/FIRType.cpp +++ b/flang/lib/Optimizer/Dialect/FIRType.cpp @@ -66,8 +66,8 @@ static bool isaIntegerType(mlir::Type ty) { bool verifyRecordMemberType(mlir::Type ty) { return !mlir::isa( - ty); + SliceType, FieldType, LenType, ReferenceType, TypeDescType, + VolatileReferenceType>(ty); } bool verifySameLists(llvm::ArrayRef a1, @@ -217,14 +217,14 @@ mlir::Type getDerivedType(mlir::Type ty) { mlir::Type dyn_cast_ptrEleTy(mlir::Type t) { return llvm::TypeSwitch(t) - .Case([](auto p) { return p.getEleTy(); }) .Default([](mlir::Type) { return mlir::Type{}; }); } mlir::Type dyn_cast_ptrOrBoxEleTy(mlir::Type t) { return llvm::TypeSwitch(t) - .Case([](auto p) { return p.getEleTy(); }) .Case( [](auto p) { return unwrapRefType(p.getEleTy()); }) @@ -596,6 +596,9 @@ std::string getTypeAsString(mlir::Type ty, const fir::KindMapping &kindMap, } else if (auto refTy = mlir::dyn_cast_or_null(ty)) { name << "ref_"; ty = refTy.getEleTy(); + } else if (auto refTy = mlir::dyn_cast_or_null(ty)) { + name << "volatile_ref_"; + ty = refTy.getEleTy(); } else if (auto ptrTy = mlir::dyn_cast_or_null(ty)) { name << "ptr_"; ty = ptrTy.getEleTy(); @@ -649,7 +652,7 @@ mlir::Type changeElementType(mlir::Type type, mlir::Type newElementType, .Case([&](fir::SequenceType seqTy) -> mlir::Type { return fir::SequenceType::get(seqTy.getShape(), newElementType); }) - .Case([&](auto t) -> mlir::Type { using FIRT = decltype(t); return FIRT::get( @@ -725,13 +728,16 @@ BoxProcType::verify(llvm::function_ref emitError, if (auto refTy = mlir::dyn_cast(eleTy)) if (mlir::isa(refTy)) return mlir::success(); + if (auto refTy = mlir::dyn_cast(eleTy)) + if (mlir::isa(refTy)) + return mlir::success(); return emitError() << "invalid type for boxproc" << eleTy << '\n'; } static bool cannotBePointerOrHeapElementType(mlir::Type eleTy) { return mlir::isa(eleTy); + ReferenceType, VolatileReferenceType, TypeDescType>(eleTy); } //===----------------------------------------------------------------------===// @@ -1070,7 +1076,7 @@ llvm::LogicalResult fir::ReferenceType::verify( llvm::function_ref emitError, mlir::Type eleTy) { if (mlir::isa(eleTy)) + ReferenceType, VolatileReferenceType, TypeDescType>(eleTy)) return emitError() << "cannot build a reference to type: " << eleTy << '\n'; return mlir::success(); } @@ -1147,7 +1153,7 @@ llvm::LogicalResult fir::SequenceType::verify( // DIMENSION attribute can only be applied to an intrinsic or record type if (mlir::isa(eleTy)) + ReferenceType, VolatileReferenceType, TypeDescType, SequenceType>(eleTy)) return emitError() << "cannot build an array of this element type: " << eleTy << '\n'; return mlir::success(); @@ -1220,7 +1226,7 @@ llvm::LogicalResult fir::TypeDescType::verify( mlir::Type eleTy) { if (mlir::isa(eleTy)) + VolatileReferenceType, TypeDescType>(eleTy)) return emitError() << "cannot build a type descriptor of type: " << eleTy << '\n'; return mlir::success(); @@ -1319,7 +1325,7 @@ changeTypeShape(mlir::Type type, return fir::SequenceType::get(*newShape, seqTy.getEleTy()); return seqTy.getEleTy(); }) - .Case([&](auto t) -> mlir::Type { using FIRT = decltype(t); return FIRT::get(changeTypeShape(t.getEleTy(), newShape)); @@ -1393,10 +1399,12 @@ void FIROpsDialect::registerTypes() { addTypes(); fir::ReferenceType::attachInterface< OpenMPPointerLikeModel>(*getContext()); + fir::VolatileReferenceType::attachInterface< + OpenMPPointerLikeModel>(*getContext()); fir::PointerType::attachInterface>( *getContext()); fir::HeapType::attachInterface>( @@ -1467,3 +1475,8 @@ fir::getTypeSizeAndAlignmentOrCrash(mlir::Location loc, mlir::Type ty, return *result; TODO(loc, "computing size of a component"); } + +llvm::LogicalResult +fir::VolatileReferenceType::verify(llvm::function_ref, mlir::Type) { + return mlir::success(); +} diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp index cb77aef74acd5..c2a485c66e0ec 100644 --- a/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp +++ b/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp @@ -82,7 +82,7 @@ void hlfir::ExprType::print(mlir::AsmPrinter &printer) const { bool hlfir::isFortranVariableType(mlir::Type type) { return llvm::TypeSwitch(type) - .Case([](auto p) { + .Case([](auto p) { mlir::Type eleType = p.getEleTy(); return mlir::isa(eleType) || !fir::hasDynamicSize(eleType); diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp index 8851a3a7187b9..d63a2ac39f6b5 100644 --- a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp +++ b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp @@ -199,6 +199,35 @@ mlir::Type hlfir::DeclareOp::getHLFIRVariableType(mlir::Type inputType, return inputType; } +// Updated version with volatile support +mlir::Type hlfir::DeclareOp::getHLFIRVariableType(mlir::Type inputType, + bool hasExplicitLowerBounds, + bool isVolatile) { + mlir::Type type = fir::unwrapRefType(inputType); + if (mlir::isa(type)) + return inputType; + if (auto charType = mlir::dyn_cast(type)) + if (charType.hasDynamicLen()) + return fir::BoxCharType::get(charType.getContext(), charType.getFKind()); + + auto seqType = mlir::dyn_cast(type); + bool hasDynamicExtents = + seqType && fir::sequenceWithNonConstantShape(seqType); + mlir::Type eleType = seqType ? seqType.getEleTy() : type; + bool hasDynamicLengthParams = fir::characterWithDynamicLen(eleType) || + fir::isRecordWithTypeParameters(eleType); + if (hasExplicitLowerBounds || hasDynamicExtents || hasDynamicLengthParams) + return fir::BoxType::get(type); + + // If this is a reference type and has the volatile attribute, use VolatileReferenceType + if (isVolatile && mlir::isa(inputType)) { + auto refType = mlir::cast(inputType); + return fir::VolatileReferenceType::get(refType.getEleTy()); + } + + return inputType; +} + static bool hasExplicitLowerBounds(mlir::Value shape) { return shape && mlir::isa(shape.getType()); @@ -214,18 +243,45 @@ void hlfir::DeclareOp::build(mlir::OpBuilder &builder, auto nameAttr = builder.getStringAttr(uniq_name); mlir::Type inputType = memref.getType(); bool hasExplicitLbs = hasExplicitLowerBounds(shape); + bool isVolatile = false; + if (fortran_attrs && mlir::isa(inputType) && + bitEnumContainsAny(fortran_attrs.getFlags(), fir::FortranVariableFlagsEnum::fortran_volatile)) { + auto refType = mlir::cast(inputType); + isVolatile = true; + inputType = fir::VolatileReferenceType::get(refType.getEleTy()); + memref = builder.create(memref.getLoc(), inputType, memref); + } mlir::Type hlfirVariableType = - getHLFIRVariableType(inputType, hasExplicitLbs); + getHLFIRVariableType(inputType, hasExplicitLbs, isVolatile); + build(builder, result, {hlfirVariableType, inputType}, memref, shape, typeparams, dummy_scope, nameAttr, fortran_attrs, data_attr); } +static bool hlfirVariableTypeCompatible(mlir::Type memrefType, mlir::Type outputType) { + // if the input and output types don't match, they are still compatible ONLY if this is + // due to the variable being declared volatile. + if (auto inputRefTy = mlir::dyn_cast(memrefType)) { + if (auto hlfirRefTy = mlir::dyn_cast(outputType)) { + return hlfirRefTy.getEleTy() == inputRefTy.getEleTy(); + } + } + return memrefType == outputType; +} + llvm::LogicalResult hlfir::DeclareOp::verify() { if (getMemref().getType() != getResult(1).getType()) return emitOpError("second result type must match input memref type"); + fir::FortranVariableFlagsAttr attrs; + bool isVolatile = false; + if (getFortranAttrs().has_value()) { + auto flagsEnum = getFortranAttrs().value(); + isVolatile = bitEnumContainsAny(flagsEnum, fir::FortranVariableFlagsEnum::fortran_volatile); + attrs = fir::FortranVariableFlagsAttr::get(getContext(), flagsEnum); + } mlir::Type hlfirVariableType = getHLFIRVariableType( - getMemref().getType(), hasExplicitLowerBounds(getShape())); - if (hlfirVariableType != getResult(0).getType()) + getMemref().getType(), hasExplicitLowerBounds(getShape()), isVolatile); + if (!hlfirVariableTypeCompatible(getResult(0).getType(), hlfirVariableType)) return emitOpError("first result type is inconsistent with variable " "properties: expected ") << hlfirVariableType; diff --git a/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp index 496a5560ac615..d591ff67db2d4 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp @@ -295,6 +295,7 @@ class DeclareOpConversion : public mlir::OpRewritePattern { mlir::PatternRewriter &rewriter) const override { mlir::Location loc = declareOp->getLoc(); mlir::Value memref = declareOp.getMemref(); + mlir::Type memrefType = memref.getType(); fir::FortranVariableFlagsAttr fortranAttrs; cuf::DataAttributeAttr dataAttr; if (auto attrs = declareOp.getFortranAttrs()) @@ -302,8 +303,11 @@ class DeclareOpConversion : public mlir::OpRewritePattern { fir::FortranVariableFlagsAttr::get(rewriter.getContext(), *attrs); if (auto attr = declareOp.getDataAttr()) dataAttr = cuf::DataAttributeAttr::get(rewriter.getContext(), *attr); + if (auto volType = mlir::dyn_cast(declareOp.getResult(0).getType())) { + memrefType = volType; + } auto firDeclareOp = rewriter.create( - loc, memref.getType(), memref, declareOp.getShape(), + loc, memrefType, memref, declareOp.getShape(), declareOp.getTypeparams(), declareOp.getDummyScope(), declareOp.getUniqName(), fortranAttrs, dataAttr); @@ -385,12 +389,19 @@ class DeclareOpConversion : public mlir::OpRewritePattern { hlfirBase = rewriter.create( loc, hlfirBaseType, firBase, declareOp.getTypeparams()[0]); } else { - if (hlfirBaseType != firBase.getType()) { + llvm::dbgs() << "hlfirBaseType: " << hlfirBaseType << "\n"; + llvm::dbgs() << "firBase.getType(): " << firBase.getType() << "\n"; + auto volType = mlir::dyn_cast(hlfirBaseType); + auto refType = mlir::dyn_cast(firBase.getType()); + if (volType && refType && volType.getEleTy() == refType.getEleTy()) { + firBase = hlfirBase; + } else if (hlfirBaseType != firBase.getType()) { declareOp.emitOpError() - << "unhandled HLFIR variable type '" << hlfirBaseType << "'\n"; + << "unhandled HLFIR variable type '" << hlfirBaseType << "' does not match fir type '" << firBase.getType() << "' with memref '" << memref << "'\n"; return mlir::failure(); + } else { + hlfirBase = firBase; } - hlfirBase = firBase; } rewriter.replaceOp(declareOp, {hlfirBase, firBase}); return mlir::success(); @@ -408,6 +419,13 @@ class DesignateOpConversion mlir::Value shape, const llvm::SmallVector &firBaseTypeParameters) { assert(!designate.getIndices().empty()); + llvm::dbgs() << "genSubscriptBeginAddr\n"; + llvm::dbgs() << "baseEleTy: " << baseEleTy << "\n"; + llvm::dbgs() << "base: " << base << "\n"; + llvm::dbgs() << "shape: " << shape << "\n"; + if (auto decl = mlir::dyn_cast(base.getDefiningOp())) { + base = decl.getResult(0); + } llvm::SmallVector firstElementIndices; auto indices = designate.getIndices(); int i = 0; @@ -418,10 +436,19 @@ class DesignateOpConversion firstElementIndices.push_back(indices[i]); i = i + (isTriplet ? 3 : 1); } - mlir::Type arrayCoorType = fir::ReferenceType::get(baseEleTy); + llvm::dbgs() << "building array coor\n"; + auto designateResultType = designate.getResult().getType(); + llvm::dbgs() << "designateResultType: " << designateResultType << "\n"; + auto isVolatile = mlir::isa(designateResultType); + llvm::dbgs() << "isVolatile: " << isVolatile << "\n"; + mlir::Type refTy = fir::ReferenceType::get(baseEleTy); + mlir::Type volTy = fir::VolatileReferenceType::get(baseEleTy); + llvm::dbgs() << "refTy: " << refTy << "\n"; + llvm::dbgs() << "volTy: " << volTy << "\n"; base = builder.create( - loc, arrayCoorType, base, shape, + loc, isVolatile ? volTy : refTy, base, shape, /*slice=*/mlir::Value{}, firstElementIndices, firBaseTypeParameters); + llvm::dbgs() << "base: " << base << "\n"; return base; } @@ -432,24 +459,35 @@ class DesignateOpConversion llvm::LogicalResult matchAndRewrite(hlfir::DesignateOp designate, mlir::PatternRewriter &rewriter) const override { + llvm::dbgs() << "DesignateOpConversion\n"; mlir::Location loc = designate.getLoc(); fir::FirOpBuilder builder(rewriter, designate.getOperation()); hlfir::Entity baseEntity(designate.getMemref()); - + bool isVolatile = mlir::isa(designate.getResult().getType()); + llvm::dbgs() << "baseEntity: " << baseEntity << "\n"; + mlir::Type baseType = baseEntity.getBase().getType(); + llvm::dbgs() << "baseType: " << baseType << "\n"; if (baseEntity.isMutableBox()) TODO(loc, "hlfir::designate load of pointer or allocatable"); mlir::Type designateResultType = designate.getResult().getType(); + llvm::dbgs() << "designateResultType: " << designateResultType << "\n"; llvm::SmallVector firBaseTypeParameters; auto [base, shape] = hlfir::genVariableFirBaseShapeAndParams( loc, builder, baseEntity, firBaseTypeParameters); mlir::Type baseEleTy = hlfir::getFortranElementType(base.getType()); mlir::Type resultEleTy = hlfir::getFortranElementType(designateResultType); + llvm::dbgs() << "base: " << base << " " << base.getType() << "\n"; + llvm::dbgs() << "shape: " << shape << " " << shape.getType() << "\n"; + llvm::dbgs() << "baseEleTy: " << baseEleTy << "\n"; + llvm::dbgs() << "resultEleTy: " << resultEleTy << "\n"; mlir::Value fieldIndex; if (designate.getComponent()) { + llvm::dbgs() << "designate.getComponent(): " << designate.getComponent() << "\n"; mlir::Type baseRecordType = baseEntity.getFortranElementType(); + llvm::dbgs() << "baseRecordType: " << baseRecordType << "\n"; if (fir::isRecordWithTypeParameters(baseRecordType)) TODO(loc, "hlfir.designate with a parametrized derived type base"); fieldIndex = builder.create( @@ -457,6 +495,7 @@ class DesignateOpConversion designate.getComponent().value(), baseRecordType, /*typeParams=*/mlir::ValueRange{}); if (baseEntity.isScalar()) { + llvm::dbgs() << "baseEntity.isScalar()\n"; // Component refs of scalar base right away: // - scalar%scalar_component [substring|complex_part] or // - scalar%static_size_array_comp @@ -464,9 +503,13 @@ class DesignateOpConversion mlir::Type componentType = mlir::cast(baseEleTy).getType( designate.getComponent().value()); + llvm::dbgs() << "componentType: " << componentType << "\n"; mlir::Type coorTy = fir::ReferenceType::get(componentType); + llvm::dbgs() << "coorTy: " << coorTy << "\n"; base = builder.create(loc, coorTy, base, fieldIndex); + llvm::dbgs() << "base: " << base << "\n"; if (mlir::isa(componentType)) { + llvm::dbgs() << "mlir::isa(componentType)\n"; auto variableInterface = mlir::cast( designate.getOperation()); if (variableInterface.isAllocatable() || @@ -479,6 +522,8 @@ class DesignateOpConversion } baseEleTy = hlfir::getFortranElementType(componentType); shape = designate.getComponentShape(); + llvm::dbgs() << "baseEleTy: " << baseEleTy << "\n"; + llvm::dbgs() << "shape: " << shape << "\n"; } else { // array%component[(indices) substring|complex part] cases. // Component ref of array bases are dealt with below in embox/rebox. @@ -487,6 +532,7 @@ class DesignateOpConversion } if (mlir::isa(designateResultType)) { + llvm::dbgs() << "mlir::isa(designateResultType)" << __LINE__ << "\n"; // Generate embox or rebox. mlir::Type eleTy = fir::unwrapPassByRefType(designateResultType); bool isScalarDesignator = !mlir::isa(eleTy); @@ -495,6 +541,7 @@ class DesignateOpConversion // The base box will be used for emboxing the scalar element. sourceBox = base; // Generate the coordinate of the element. + llvm::dbgs() << "isScalarDesignator\n"; base = genSubscriptBeginAddr(builder, loc, designate, baseEleTy, base, shape, firBaseTypeParameters); shape = nullptr; @@ -581,8 +628,10 @@ class DesignateOpConversion // shape. The base may be an array, or a scalar. mlir::Type resultAddressType = designateResultType; if (auto boxCharType = - mlir::dyn_cast(designateResultType)) + mlir::dyn_cast(designateResultType)) { + llvm::dbgs() << "mlir::dyn_cast(designateResultType)" << __LINE__ << "\n"; resultAddressType = fir::ReferenceType::get(boxCharType.getEleTy()); + } // Array element indexing. if (!designate.getIndices().empty()) { @@ -590,8 +639,11 @@ class DesignateOpConversion // - scalar%array_comp(indices) [substring|complex_part] // This may be a ranked contiguous array section in which case // The first element address is being computed. + llvm::dbgs() << "!designate.getIndices().empty()" << __LINE__ << "\n"; + llvm::dbgs() << "base before genSubscriptBeginAddr: " << __LINE__ << " " << base << "\n"; base = genSubscriptBeginAddr(builder, loc, designate, baseEleTy, base, shape, firBaseTypeParameters); + llvm::dbgs() << "base after genSubscriptBeginAddr: " << __LINE__ << base << "\n"; } // Scalar substring (potentially on the previously built array element or @@ -602,6 +654,7 @@ class DesignateOpConversion // Scalar complex part ref if (designate.getComplexPart()) { + llvm::dbgs() << "designate.getComplexPart()" << __LINE__ << "\n"; // Sequence types should have already been handled by this point assert(!mlir::isa(designateResultType)); auto index = builder.createIntegerConstant(loc, builder.getIndexType(), @@ -612,13 +665,18 @@ class DesignateOpConversion // Cast/embox the computed scalar address if needed. if (mlir::isa(designateResultType)) { + llvm::dbgs() << "mlir::isa(designateResultType)" << __LINE__ << "\n"; assert(designate.getTypeparams().size() == 1 && "must have character length"); auto emboxChar = builder.create( loc, designateResultType, base, designate.getTypeparams()[0]); rewriter.replaceOp(designate, emboxChar.getResult()); } else { + llvm::dbgs() << "createConvert" << __LINE__ << "\n"; + llvm::dbgs() << "designateResultType: " << designateResultType << "\n"; + llvm::dbgs() << "base: " << base << "\n"; base = builder.createConvert(loc, designateResultType, base); + llvm::dbgs() << "base after conversion: " << base << "\n"; rewriter.replaceOp(designate, base); } return mlir::success(); From d10cfccbab9586ee56aec27626727796d3d8e7c3 Mon Sep 17 00:00:00 2001 From: Asher Mancinelli Date: Tue, 18 Mar 2025 13:56:14 -0700 Subject: [PATCH 2/5] cleanup --- .../flang/Optimizer/Builder/HLFIRTools.h | 1 - .../include/flang/Optimizer/CodeGen/CGOps.td | 8 --- .../include/flang/Optimizer/HLFIR/HLFIROps.td | 3 +- flang/lib/Optimizer/Builder/HLFIRTools.cpp | 4 -- flang/lib/Optimizer/CodeGen/CodeGen.cpp | 11 ---- flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp | 23 +------- .../HLFIR/Transforms/ConvertToFIR.cpp | 56 +------------------ 7 files changed, 5 insertions(+), 101 deletions(-) diff --git a/flang/include/flang/Optimizer/Builder/HLFIRTools.h b/flang/include/flang/Optimizer/Builder/HLFIRTools.h index 0128e8f558f0b..4d4d4e7c0e895 100644 --- a/flang/include/flang/Optimizer/Builder/HLFIRTools.h +++ b/flang/include/flang/Optimizer/Builder/HLFIRTools.h @@ -51,7 +51,6 @@ inline bool isFortranEntityWithAttributes(mlir::Value value) { class Entity : public mlir::Value { public: explicit Entity(mlir::Value value) : mlir::Value(value) { - llvm::dbgs() << value << " " << value.getType() << "\n"; assert(isFortranEntity(value) && "must be a value representing a Fortran value or variable like"); } diff --git a/flang/include/flang/Optimizer/CodeGen/CGOps.td b/flang/include/flang/Optimizer/CodeGen/CGOps.td index adb20f8b7264b..a0b92ae97df5c 100644 --- a/flang/include/flang/Optimizer/CodeGen/CGOps.td +++ b/flang/include/flang/Optimizer/CodeGen/CGOps.td @@ -30,14 +30,6 @@ def fircg_Dialect : Dialect { class fircg_Op traits> : Op; -def fircg_VolatileLoadOp : fircg_Op<"volatile_load", []> { - let summary = "for internal conversion only"; - let description = [{ - arg must be ref of volatile - }]; - let arguments = (ins Arg:$memref); -} - // Extended embox operation. def fircg_XEmboxOp : fircg_Op<"ext_embox", [AttrSizedOperandSegments]> { let summary = "for internal conversion only"; diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td index 5935d0b356d38..cdd2b776fd186 100644 --- a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td +++ b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td @@ -124,8 +124,7 @@ def hlfir_DeclareOp : hlfir_Op<"declare", [AttrSizedOperandSegments, /// Given a FIR memory type, and information about non default lower /// bounds, get the related HLFIR variable type. - static mlir::Type getHLFIRVariableType(mlir::Type type, bool hasLowerBounds); - static mlir::Type getHLFIRVariableType(mlir::Type type, bool hasLowerBounds, bool isVolatile); + static mlir::Type getHLFIRVariableType(mlir::Type type, bool hasLowerBounds, bool isVolatile=false); }]; let hasVerifier = 1; diff --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp index e1d1f94e49edc..0d076ac59c8bc 100644 --- a/flang/lib/Optimizer/Builder/HLFIRTools.cpp +++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp @@ -751,19 +751,15 @@ std::pair hlfir::genVariableFirBaseShapeAndParams( mlir::Location loc, fir::FirOpBuilder &builder, Entity entity, llvm::SmallVectorImpl &typeParams) { auto [exv, cleanup] = translateToExtendedValue(loc, builder, entity); - llvm::dbgs() << "exv: " << exv << "\n"; assert(!cleanup && "variable to Exv should not produce cleanup"); if (entity.hasLengthParameters()) { - llvm::dbgs() << "entity.hasLengthParameters()\n"; auto params = fir::getTypeParams(exv); typeParams.append(params.begin(), params.end()); } if (entity.isScalar()) { - llvm::dbgs() << "entity.isScalar()\n"; return {fir::getBase(exv), mlir::Value{}}; } if (auto variableInterface = entity.getIfVariableInterface()) { - llvm::dbgs() << "variableInterface: " << variableInterface << "\n"; return {fir::getBase(exv), asEmboxShape(loc, builder, exv, variableInterface.getShape())}; } diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp index c3613832d2872..9d640f936a268 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -699,17 +699,11 @@ struct ConvertOpConversion : public fir::FIROpConversion { llvm::LogicalResult matchAndRewrite(fir::ConvertOp convert, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { - llvm::dbgs() << "ConvertOpConversion\n"; auto fromFirTy = convert.getValue().getType(); auto toFirTy = convert.getRes().getType(); - llvm::dbgs() << "fromFirTy: " << fromFirTy << "\n"; - llvm::dbgs() << "toFirTy: " << toFirTy << "\n"; auto fromTy = convertType(fromFirTy); auto toTy = convertType(toFirTy); - llvm::dbgs() << "fromTy: " << fromTy << "\n"; - llvm::dbgs() << "toTy: " << toTy << "\n"; mlir::Value op0 = adaptor.getOperands()[0]; - llvm::dbgs() << "op0: " << op0 << "\n"; if (fromFirTy == toFirTy) { rewriter.replaceOp(convert, op0); return mlir::success(); @@ -3224,7 +3218,6 @@ struct LoadOpConversion : public fir::FIROpConversion { mlir::Type originalLoadTy = load.getMemref().getType(); const bool isVolatile = mlir::isa(originalLoadTy); - auto volatileAttr = mlir::UnitAttr::get(load.getContext()); mlir::Type llvmLoadTy = convertObjectType(load.getType()); if (auto boxTy = mlir::dyn_cast(load.getType())) { // fir.box is a special case because it is considered an ssa value in @@ -3544,9 +3537,6 @@ struct StoreOpConversion : public fir::FIROpConversion { mlir::Value llvmValue = adaptor.getValue(); mlir::Value llvmMemref = adaptor.getMemref(); mlir::LLVM::AliasAnalysisOpInterface newOp; - llvm::dbgs() << "storeTy: " << storeTy << "\n"; - llvm::dbgs() << "originalStoreTy: " << originalStoreTy << "\n"; - llvm::dbgs() << "isVolatile: " << isVolatile << "\n"; if (auto boxTy = mlir::dyn_cast(storeTy)) { mlir::Type llvmBoxTy = lowerTy().convertBoxTypeAsStruct(boxTy); // Always use memcpy because LLVM is not as effective at optimizing @@ -3559,7 +3549,6 @@ struct StoreOpConversion : public fir::FIROpConversion { } else { newOp = rewriter.create(loc, llvmValue, llvmMemref, 0, isVolatile, false); } - llvm::dbgs() << "newOp: " << newOp << "\n"; if (std::optional optionalTag = store.getTbaa()) newOp.setTBAATags(*optionalTag); else diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp index d63a2ac39f6b5..40b68633f851c 100644 --- a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp +++ b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp @@ -179,27 +179,6 @@ void hlfir::AssignOp::getEffects( /// Given a FIR memory type, and information about non default lower bounds, get /// the related HLFIR variable type. -mlir::Type hlfir::DeclareOp::getHLFIRVariableType(mlir::Type inputType, - bool hasExplicitLowerBounds) { - mlir::Type type = fir::unwrapRefType(inputType); - if (mlir::isa(type)) - return inputType; - if (auto charType = mlir::dyn_cast(type)) - if (charType.hasDynamicLen()) - return fir::BoxCharType::get(charType.getContext(), charType.getFKind()); - - auto seqType = mlir::dyn_cast(type); - bool hasDynamicExtents = - seqType && fir::sequenceWithNonConstantShape(seqType); - mlir::Type eleType = seqType ? seqType.getEleTy() : type; - bool hasDynamicLengthParams = fir::characterWithDynamicLen(eleType) || - fir::isRecordWithTypeParameters(eleType); - if (hasExplicitLowerBounds || hasDynamicExtents || hasDynamicLengthParams) - return fir::BoxType::get(type); - return inputType; -} - -// Updated version with volatile support mlir::Type hlfir::DeclareOp::getHLFIRVariableType(mlir::Type inputType, bool hasExplicitLowerBounds, bool isVolatile) { @@ -224,7 +203,7 @@ mlir::Type hlfir::DeclareOp::getHLFIRVariableType(mlir::Type inputType, auto refType = mlir::cast(inputType); return fir::VolatileReferenceType::get(refType.getEleTy()); } - + return inputType; } diff --git a/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp index d591ff67db2d4..9f4c9e5304aee 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp @@ -295,7 +295,6 @@ class DeclareOpConversion : public mlir::OpRewritePattern { mlir::PatternRewriter &rewriter) const override { mlir::Location loc = declareOp->getLoc(); mlir::Value memref = declareOp.getMemref(); - mlir::Type memrefType = memref.getType(); fir::FortranVariableFlagsAttr fortranAttrs; cuf::DataAttributeAttr dataAttr; if (auto attrs = declareOp.getFortranAttrs()) @@ -303,11 +302,8 @@ class DeclareOpConversion : public mlir::OpRewritePattern { fir::FortranVariableFlagsAttr::get(rewriter.getContext(), *attrs); if (auto attr = declareOp.getDataAttr()) dataAttr = cuf::DataAttributeAttr::get(rewriter.getContext(), *attr); - if (auto volType = mlir::dyn_cast(declareOp.getResult(0).getType())) { - memrefType = volType; - } auto firDeclareOp = rewriter.create( - loc, memrefType, memref, declareOp.getShape(), + loc, memref.getType(), memref, declareOp.getShape(), declareOp.getTypeparams(), declareOp.getDummyScope(), declareOp.getUniqName(), fortranAttrs, dataAttr); @@ -389,19 +385,12 @@ class DeclareOpConversion : public mlir::OpRewritePattern { hlfirBase = rewriter.create( loc, hlfirBaseType, firBase, declareOp.getTypeparams()[0]); } else { - llvm::dbgs() << "hlfirBaseType: " << hlfirBaseType << "\n"; - llvm::dbgs() << "firBase.getType(): " << firBase.getType() << "\n"; - auto volType = mlir::dyn_cast(hlfirBaseType); - auto refType = mlir::dyn_cast(firBase.getType()); - if (volType && refType && volType.getEleTy() == refType.getEleTy()) { - firBase = hlfirBase; - } else if (hlfirBaseType != firBase.getType()) { + if (hlfirBaseType != firBase.getType()) { declareOp.emitOpError() << "unhandled HLFIR variable type '" << hlfirBaseType << "' does not match fir type '" << firBase.getType() << "' with memref '" << memref << "'\n"; return mlir::failure(); - } else { - hlfirBase = firBase; } + hlfirBase = firBase; } rewriter.replaceOp(declareOp, {hlfirBase, firBase}); return mlir::success(); @@ -419,10 +408,6 @@ class DesignateOpConversion mlir::Value shape, const llvm::SmallVector &firBaseTypeParameters) { assert(!designate.getIndices().empty()); - llvm::dbgs() << "genSubscriptBeginAddr\n"; - llvm::dbgs() << "baseEleTy: " << baseEleTy << "\n"; - llvm::dbgs() << "base: " << base << "\n"; - llvm::dbgs() << "shape: " << shape << "\n"; if (auto decl = mlir::dyn_cast(base.getDefiningOp())) { base = decl.getResult(0); } @@ -436,19 +421,13 @@ class DesignateOpConversion firstElementIndices.push_back(indices[i]); i = i + (isTriplet ? 3 : 1); } - llvm::dbgs() << "building array coor\n"; auto designateResultType = designate.getResult().getType(); - llvm::dbgs() << "designateResultType: " << designateResultType << "\n"; auto isVolatile = mlir::isa(designateResultType); - llvm::dbgs() << "isVolatile: " << isVolatile << "\n"; mlir::Type refTy = fir::ReferenceType::get(baseEleTy); mlir::Type volTy = fir::VolatileReferenceType::get(baseEleTy); - llvm::dbgs() << "refTy: " << refTy << "\n"; - llvm::dbgs() << "volTy: " << volTy << "\n"; base = builder.create( loc, isVolatile ? volTy : refTy, base, shape, /*slice=*/mlir::Value{}, firstElementIndices, firBaseTypeParameters); - llvm::dbgs() << "base: " << base << "\n"; return base; } @@ -459,35 +438,25 @@ class DesignateOpConversion llvm::LogicalResult matchAndRewrite(hlfir::DesignateOp designate, mlir::PatternRewriter &rewriter) const override { - llvm::dbgs() << "DesignateOpConversion\n"; mlir::Location loc = designate.getLoc(); fir::FirOpBuilder builder(rewriter, designate.getOperation()); hlfir::Entity baseEntity(designate.getMemref()); bool isVolatile = mlir::isa(designate.getResult().getType()); - llvm::dbgs() << "baseEntity: " << baseEntity << "\n"; mlir::Type baseType = baseEntity.getBase().getType(); - llvm::dbgs() << "baseType: " << baseType << "\n"; if (baseEntity.isMutableBox()) TODO(loc, "hlfir::designate load of pointer or allocatable"); mlir::Type designateResultType = designate.getResult().getType(); - llvm::dbgs() << "designateResultType: " << designateResultType << "\n"; llvm::SmallVector firBaseTypeParameters; auto [base, shape] = hlfir::genVariableFirBaseShapeAndParams( loc, builder, baseEntity, firBaseTypeParameters); mlir::Type baseEleTy = hlfir::getFortranElementType(base.getType()); mlir::Type resultEleTy = hlfir::getFortranElementType(designateResultType); - llvm::dbgs() << "base: " << base << " " << base.getType() << "\n"; - llvm::dbgs() << "shape: " << shape << " " << shape.getType() << "\n"; - llvm::dbgs() << "baseEleTy: " << baseEleTy << "\n"; - llvm::dbgs() << "resultEleTy: " << resultEleTy << "\n"; mlir::Value fieldIndex; if (designate.getComponent()) { - llvm::dbgs() << "designate.getComponent(): " << designate.getComponent() << "\n"; mlir::Type baseRecordType = baseEntity.getFortranElementType(); - llvm::dbgs() << "baseRecordType: " << baseRecordType << "\n"; if (fir::isRecordWithTypeParameters(baseRecordType)) TODO(loc, "hlfir.designate with a parametrized derived type base"); fieldIndex = builder.create( @@ -495,7 +464,6 @@ class DesignateOpConversion designate.getComponent().value(), baseRecordType, /*typeParams=*/mlir::ValueRange{}); if (baseEntity.isScalar()) { - llvm::dbgs() << "baseEntity.isScalar()\n"; // Component refs of scalar base right away: // - scalar%scalar_component [substring|complex_part] or // - scalar%static_size_array_comp @@ -503,13 +471,9 @@ class DesignateOpConversion mlir::Type componentType = mlir::cast(baseEleTy).getType( designate.getComponent().value()); - llvm::dbgs() << "componentType: " << componentType << "\n"; mlir::Type coorTy = fir::ReferenceType::get(componentType); - llvm::dbgs() << "coorTy: " << coorTy << "\n"; base = builder.create(loc, coorTy, base, fieldIndex); - llvm::dbgs() << "base: " << base << "\n"; if (mlir::isa(componentType)) { - llvm::dbgs() << "mlir::isa(componentType)\n"; auto variableInterface = mlir::cast( designate.getOperation()); if (variableInterface.isAllocatable() || @@ -522,8 +486,6 @@ class DesignateOpConversion } baseEleTy = hlfir::getFortranElementType(componentType); shape = designate.getComponentShape(); - llvm::dbgs() << "baseEleTy: " << baseEleTy << "\n"; - llvm::dbgs() << "shape: " << shape << "\n"; } else { // array%component[(indices) substring|complex part] cases. // Component ref of array bases are dealt with below in embox/rebox. @@ -532,7 +494,6 @@ class DesignateOpConversion } if (mlir::isa(designateResultType)) { - llvm::dbgs() << "mlir::isa(designateResultType)" << __LINE__ << "\n"; // Generate embox or rebox. mlir::Type eleTy = fir::unwrapPassByRefType(designateResultType); bool isScalarDesignator = !mlir::isa(eleTy); @@ -541,7 +502,6 @@ class DesignateOpConversion // The base box will be used for emboxing the scalar element. sourceBox = base; // Generate the coordinate of the element. - llvm::dbgs() << "isScalarDesignator\n"; base = genSubscriptBeginAddr(builder, loc, designate, baseEleTy, base, shape, firBaseTypeParameters); shape = nullptr; @@ -629,7 +589,6 @@ class DesignateOpConversion mlir::Type resultAddressType = designateResultType; if (auto boxCharType = mlir::dyn_cast(designateResultType)) { - llvm::dbgs() << "mlir::dyn_cast(designateResultType)" << __LINE__ << "\n"; resultAddressType = fir::ReferenceType::get(boxCharType.getEleTy()); } @@ -639,11 +598,8 @@ class DesignateOpConversion // - scalar%array_comp(indices) [substring|complex_part] // This may be a ranked contiguous array section in which case // The first element address is being computed. - llvm::dbgs() << "!designate.getIndices().empty()" << __LINE__ << "\n"; - llvm::dbgs() << "base before genSubscriptBeginAddr: " << __LINE__ << " " << base << "\n"; base = genSubscriptBeginAddr(builder, loc, designate, baseEleTy, base, shape, firBaseTypeParameters); - llvm::dbgs() << "base after genSubscriptBeginAddr: " << __LINE__ << base << "\n"; } // Scalar substring (potentially on the previously built array element or @@ -654,7 +610,6 @@ class DesignateOpConversion // Scalar complex part ref if (designate.getComplexPart()) { - llvm::dbgs() << "designate.getComplexPart()" << __LINE__ << "\n"; // Sequence types should have already been handled by this point assert(!mlir::isa(designateResultType)); auto index = builder.createIntegerConstant(loc, builder.getIndexType(), @@ -665,18 +620,13 @@ class DesignateOpConversion // Cast/embox the computed scalar address if needed. if (mlir::isa(designateResultType)) { - llvm::dbgs() << "mlir::isa(designateResultType)" << __LINE__ << "\n"; assert(designate.getTypeparams().size() == 1 && "must have character length"); auto emboxChar = builder.create( loc, designateResultType, base, designate.getTypeparams()[0]); rewriter.replaceOp(designate, emboxChar.getResult()); } else { - llvm::dbgs() << "createConvert" << __LINE__ << "\n"; - llvm::dbgs() << "designateResultType: " << designateResultType << "\n"; - llvm::dbgs() << "base: " << base << "\n"; base = builder.createConvert(loc, designateResultType, base); - llvm::dbgs() << "base after conversion: " << base << "\n"; rewriter.replaceOp(designate, base); } return mlir::success(); From 427c3d489bd0a63d5d265bae5c387856ad3d56eb Mon Sep 17 00:00:00 2001 From: Asher Mancinelli Date: Tue, 18 Mar 2025 14:01:11 -0700 Subject: [PATCH 3/5] cleanup --- .../flang/Optimizer/Builder/HLFIRTools.h | 3 +- flang/lib/Lower/ConvertExprToHLFIR.cpp | 11 ++--- flang/lib/Optimizer/CodeGen/CodeGen.cpp | 13 +++-- flang/lib/Optimizer/Dialect/FIROps.cpp | 10 ++-- flang/lib/Optimizer/Dialect/FIRType.cpp | 49 +++++++++++-------- flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp | 3 +- flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp | 21 +++++--- .../HLFIR/Transforms/ConvertToFIR.cpp | 9 ++-- 8 files changed, 67 insertions(+), 52 deletions(-) diff --git a/flang/include/flang/Optimizer/Builder/HLFIRTools.h b/flang/include/flang/Optimizer/Builder/HLFIRTools.h index 4d4d4e7c0e895..6390a3a926fba 100644 --- a/flang/include/flang/Optimizer/Builder/HLFIRTools.h +++ b/flang/include/flang/Optimizer/Builder/HLFIRTools.h @@ -65,7 +65,8 @@ class Entity : public mlir::Value { bool isVolatile() const { if (auto iface = getIfVariableInterface()) { if (auto attrs = iface.getFortranAttrs()) { - return bitEnumContainsAny(attrs.value(), fir::FortranVariableFlagsEnum::fortran_volatile); + return bitEnumContainsAny( + attrs.value(), fir::FortranVariableFlagsEnum::fortran_volatile); } } return false; diff --git a/flang/lib/Lower/ConvertExprToHLFIR.cpp b/flang/lib/Lower/ConvertExprToHLFIR.cpp index 23d938d164bef..5fb8952fcd0ee 100644 --- a/flang/lib/Lower/ConvertExprToHLFIR.cpp +++ b/flang/lib/Lower/ConvertExprToHLFIR.cpp @@ -36,11 +36,6 @@ namespace { -/// Determine if a given symbol or designator has the VOLATILE attribute. -static bool isVolatileDesignator(const Fortran::semantics::Symbol &symbol) { - return symbol.GetUltimate().attrs().test(Fortran::semantics::Attr::VOLATILE); -} - /// Determine if a given symbol has the VOLATILE attribute. static bool isVolatileSymbol(const Fortran::semantics::Symbol &symbol) { return symbol.GetUltimate().attrs().test(Fortran::semantics::Attr::VOLATILE); @@ -235,10 +230,12 @@ class HlfirDesignatorBuilder { return fir::BoxType::get(resultValueType); // Check if this should be a volatile reference - if constexpr (std::is_same_v, Fortran::evaluate::SymbolRef>) { + if constexpr (std::is_same_v, + Fortran::evaluate::SymbolRef>) { if (isVolatileSymbol(designatorNode.get())) return fir::VolatileReferenceType::get(resultValueType); - } else if constexpr (std::is_same_v, Fortran::evaluate::Component>) { + } else if constexpr (std::is_same_v, + Fortran::evaluate::Component>) { if (isVolatileSymbol(designatorNode.GetLastSymbol())) return fir::VolatileReferenceType::get(resultValueType); } diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp index 9d640f936a268..f7ed8ce7c6d54 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -3217,7 +3217,8 @@ struct LoadOpConversion : public fir::FIROpConversion { mlir::ConversionPatternRewriter &rewriter) const override { mlir::Type originalLoadTy = load.getMemref().getType(); - const bool isVolatile = mlir::isa(originalLoadTy); + const bool isVolatile = + mlir::isa(originalLoadTy); mlir::Type llvmLoadTy = convertObjectType(load.getType()); if (auto boxTy = mlir::dyn_cast(load.getType())) { // fir.box is a special case because it is considered an ssa value in @@ -3533,7 +3534,8 @@ struct StoreOpConversion : public fir::FIROpConversion { mlir::Location loc = store.getLoc(); mlir::Type storeTy = store.getValue().getType(); mlir::Type originalStoreTy = store.getMemref().getType(); - const bool isVolatile = mlir::isa(originalStoreTy); + const bool isVolatile = + mlir::isa(originalStoreTy); mlir::Value llvmValue = adaptor.getValue(); mlir::Value llvmMemref = adaptor.getMemref(); mlir::LLVM::AliasAnalysisOpInterface newOp; @@ -3544,10 +3546,11 @@ struct StoreOpConversion : public fir::FIROpConversion { TypePair boxTypePair{boxTy, llvmBoxTy}; mlir::Value boxSize = computeBoxSize(loc, boxTypePair, llvmValue, rewriter); - newOp = rewriter.create( - loc, llvmMemref, llvmValue, boxSize, isVolatile); + newOp = rewriter.create(loc, llvmMemref, llvmValue, + boxSize, isVolatile); } else { - newOp = rewriter.create(loc, llvmValue, llvmMemref, 0, isVolatile, false); + newOp = rewriter.create(loc, llvmValue, llvmMemref, + 0, isVolatile, false); } if (std::optional optionalTag = store.getTbaa()) newOp.setTBAATags(*optionalTag); diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp index 8ec5c8fee69c7..632eb4f2ea612 100644 --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -305,8 +305,8 @@ static mlir::Type wrapAllocMemResultType(mlir::Type intype) { // Fortran semantics: C852 an entity cannot be both ALLOCATABLE and POINTER // 8.5.3 note 1 prohibits ALLOCATABLE procedures as well // FIR semantics: one may not allocate a memory reference value - if (mlir::isa(intype)) + if (mlir::isa(intype)) return {}; return fir::HeapType::get(intype); } @@ -1368,9 +1368,9 @@ bool fir::ConvertOp::isFloatCompatible(mlir::Type ty) { bool fir::ConvertOp::isPointerCompatible(mlir::Type ty) { return mlir::isa(ty); + fir::PointerType, fir::HeapType, fir::LLVMPointerType, + mlir::MemRefType, mlir::FunctionType, fir::TypeDescType, + mlir::LLVM::LLVMPointerType>(ty); } static std::optional getVectorElementType(mlir::Type ty) { diff --git a/flang/lib/Optimizer/Dialect/FIRType.cpp b/flang/lib/Optimizer/Dialect/FIRType.cpp index ea663cc42e647..7896c75aef60b 100644 --- a/flang/lib/Optimizer/Dialect/FIRType.cpp +++ b/flang/lib/Optimizer/Dialect/FIRType.cpp @@ -217,15 +217,17 @@ mlir::Type getDerivedType(mlir::Type ty) { mlir::Type dyn_cast_ptrEleTy(mlir::Type t) { return llvm::TypeSwitch(t) - .Case([](auto p) { return p.getEleTy(); }) + .Case( + [](auto p) { return p.getEleTy(); }) .Default([](mlir::Type) { return mlir::Type{}; }); } mlir::Type dyn_cast_ptrOrBoxEleTy(mlir::Type t) { return llvm::TypeSwitch(t) - .Case([](auto p) { return p.getEleTy(); }) + .Case( + [](auto p) { return p.getEleTy(); }) .Case( [](auto p) { return unwrapRefType(p.getEleTy()); }) .Default([](mlir::Type) { return mlir::Type{}; }); @@ -596,7 +598,8 @@ std::string getTypeAsString(mlir::Type ty, const fir::KindMapping &kindMap, } else if (auto refTy = mlir::dyn_cast_or_null(ty)) { name << "ref_"; ty = refTy.getEleTy(); - } else if (auto refTy = mlir::dyn_cast_or_null(ty)) { + } else if (auto refTy = + mlir::dyn_cast_or_null(ty)) { name << "volatile_ref_"; ty = refTy.getEleTy(); } else if (auto ptrTy = mlir::dyn_cast_or_null(ty)) { @@ -652,12 +655,13 @@ mlir::Type changeElementType(mlir::Type type, mlir::Type newElementType, .Case([&](fir::SequenceType seqTy) -> mlir::Type { return fir::SequenceType::get(seqTy.getShape(), newElementType); }) - .Case([&](auto t) -> mlir::Type { - using FIRT = decltype(t); - return FIRT::get( - changeElementType(t.getEleTy(), newElementType, turnBoxIntoClass)); - }) + .Case( + [&](auto t) -> mlir::Type { + using FIRT = decltype(t); + return FIRT::get(changeElementType(t.getEleTy(), newElementType, + turnBoxIntoClass)); + }) .Case([&](fir::BoxType t) -> mlir::Type { mlir::Type newInnerType = changeElementType(t.getEleTy(), newElementType, false); @@ -1153,7 +1157,8 @@ llvm::LogicalResult fir::SequenceType::verify( // DIMENSION attribute can only be applied to an intrinsic or record type if (mlir::isa(eleTy)) + ReferenceType, VolatileReferenceType, TypeDescType, + SequenceType>(eleTy)) return emitError() << "cannot build an array of this element type: " << eleTy << '\n'; return mlir::success(); @@ -1325,11 +1330,12 @@ changeTypeShape(mlir::Type type, return fir::SequenceType::get(*newShape, seqTy.getEleTy()); return seqTy.getEleTy(); }) - .Case([&](auto t) -> mlir::Type { - using FIRT = decltype(t); - return FIRT::get(changeTypeShape(t.getEleTy(), newShape)); - }) + .Case( + [&](auto t) -> mlir::Type { + using FIRT = decltype(t); + return FIRT::get(changeTypeShape(t.getEleTy(), newShape)); + }) .Default([&](mlir::Type t) -> mlir::Type { assert((fir::isa_trivial(t) || llvm::isa(t) || llvm::isa(t)) && @@ -1399,8 +1405,9 @@ void FIROpsDialect::registerTypes() { addTypes(); + VolatileReferenceType, SequenceType, ShapeType, ShapeShiftType, + ShiftType, SliceType, TypeDescType, fir::VectorType, + fir::DummyScopeType>(); fir::ReferenceType::attachInterface< OpenMPPointerLikeModel>(*getContext()); fir::VolatileReferenceType::attachInterface< @@ -1476,7 +1483,7 @@ fir::getTypeSizeAndAlignmentOrCrash(mlir::Location loc, mlir::Type ty, TODO(loc, "computing size of a component"); } -llvm::LogicalResult -fir::VolatileReferenceType::verify(llvm::function_ref, mlir::Type) { +llvm::LogicalResult fir::VolatileReferenceType::verify( + llvm::function_ref, mlir::Type) { return mlir::success(); } diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp index c2a485c66e0ec..827564aa60471 100644 --- a/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp +++ b/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp @@ -82,7 +82,8 @@ void hlfir::ExprType::print(mlir::AsmPrinter &printer) const { bool hlfir::isFortranVariableType(mlir::Type type) { return llvm::TypeSwitch(type) - .Case([](auto p) { + .Case([](auto p) { mlir::Type eleType = p.getEleTy(); return mlir::isa(eleType) || !fir::hasDynamicSize(eleType); diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp index 40b68633f851c..f7abeb202a6a8 100644 --- a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp +++ b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp @@ -198,7 +198,8 @@ mlir::Type hlfir::DeclareOp::getHLFIRVariableType(mlir::Type inputType, if (hasExplicitLowerBounds || hasDynamicExtents || hasDynamicLengthParams) return fir::BoxType::get(type); - // If this is a reference type and has the volatile attribute, use VolatileReferenceType + // If this is a reference type and has the volatile attribute, use + // VolatileReferenceType if (isVolatile && mlir::isa(inputType)) { auto refType = mlir::cast(inputType); return fir::VolatileReferenceType::get(refType.getEleTy()); @@ -224,7 +225,8 @@ void hlfir::DeclareOp::build(mlir::OpBuilder &builder, bool hasExplicitLbs = hasExplicitLowerBounds(shape); bool isVolatile = false; if (fortran_attrs && mlir::isa(inputType) && - bitEnumContainsAny(fortran_attrs.getFlags(), fir::FortranVariableFlagsEnum::fortran_volatile)) { + bitEnumContainsAny(fortran_attrs.getFlags(), + fir::FortranVariableFlagsEnum::fortran_volatile)) { auto refType = mlir::cast(inputType); isVolatile = true; inputType = fir::VolatileReferenceType::get(refType.getEleTy()); @@ -232,16 +234,18 @@ void hlfir::DeclareOp::build(mlir::OpBuilder &builder, } mlir::Type hlfirVariableType = getHLFIRVariableType(inputType, hasExplicitLbs, isVolatile); - + build(builder, result, {hlfirVariableType, inputType}, memref, shape, typeparams, dummy_scope, nameAttr, fortran_attrs, data_attr); } -static bool hlfirVariableTypeCompatible(mlir::Type memrefType, mlir::Type outputType) { - // if the input and output types don't match, they are still compatible ONLY if this is - // due to the variable being declared volatile. +static bool hlfirVariableTypeCompatible(mlir::Type memrefType, + mlir::Type outputType) { + // if the input and output types don't match, they are still compatible ONLY + // if this is due to the variable being declared volatile. if (auto inputRefTy = mlir::dyn_cast(memrefType)) { - if (auto hlfirRefTy = mlir::dyn_cast(outputType)) { + if (auto hlfirRefTy = + mlir::dyn_cast(outputType)) { return hlfirRefTy.getEleTy() == inputRefTy.getEleTy(); } } @@ -255,7 +259,8 @@ llvm::LogicalResult hlfir::DeclareOp::verify() { bool isVolatile = false; if (getFortranAttrs().has_value()) { auto flagsEnum = getFortranAttrs().value(); - isVolatile = bitEnumContainsAny(flagsEnum, fir::FortranVariableFlagsEnum::fortran_volatile); + isVolatile = bitEnumContainsAny( + flagsEnum, fir::FortranVariableFlagsEnum::fortran_volatile); attrs = fir::FortranVariableFlagsAttr::get(getContext(), flagsEnum); } mlir::Type hlfirVariableType = getHLFIRVariableType( diff --git a/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp index 9f4c9e5304aee..1dc6e659f6136 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp @@ -387,7 +387,9 @@ class DeclareOpConversion : public mlir::OpRewritePattern { } else { if (hlfirBaseType != firBase.getType()) { declareOp.emitOpError() - << "unhandled HLFIR variable type '" << hlfirBaseType << "' does not match fir type '" << firBase.getType() << "' with memref '" << memref << "'\n"; + << "unhandled HLFIR variable type '" << hlfirBaseType + << "' does not match fir type '" << firBase.getType() + << "' with memref '" << memref << "'\n"; return mlir::failure(); } hlfirBase = firBase; @@ -422,7 +424,8 @@ class DesignateOpConversion i = i + (isTriplet ? 3 : 1); } auto designateResultType = designate.getResult().getType(); - auto isVolatile = mlir::isa(designateResultType); + auto isVolatile = + mlir::isa(designateResultType); mlir::Type refTy = fir::ReferenceType::get(baseEleTy); mlir::Type volTy = fir::VolatileReferenceType::get(baseEleTy); base = builder.create( @@ -442,8 +445,6 @@ class DesignateOpConversion fir::FirOpBuilder builder(rewriter, designate.getOperation()); hlfir::Entity baseEntity(designate.getMemref()); - bool isVolatile = mlir::isa(designate.getResult().getType()); - mlir::Type baseType = baseEntity.getBase().getType(); if (baseEntity.isMutableBox()) TODO(loc, "hlfir::designate load of pointer or allocatable"); From ff158ab414115c6512c5722f061ec5263d77acca Mon Sep 17 00:00:00 2001 From: Asher Mancinelli Date: Tue, 18 Mar 2025 15:23:31 -0700 Subject: [PATCH 4/5] cleanup --- .../flang/Optimizer/Dialect/FIRTypes.td | 39 ++++++++++++------- flang/lib/Optimizer/Dialect/FIROps.cpp | 22 +++++------ .../HLFIR/Transforms/ConvertToFIR.cpp | 6 +-- 3 files changed, 39 insertions(+), 28 deletions(-) diff --git a/flang/include/flang/Optimizer/Dialect/FIRTypes.td b/flang/include/flang/Optimizer/Dialect/FIRTypes.td index 2eb0e5ba1af80..da2de318ae8b3 100644 --- a/flang/include/flang/Optimizer/Dialect/FIRTypes.td +++ b/flang/include/flang/Optimizer/Dialect/FIRTypes.td @@ -390,8 +390,8 @@ def fir_VolatileReferenceType : FIR_Type<"VolatileReference", "volatile_ref"> { let parameters = (ins "mlir::Type":$eleTy); - let builders = [ - TypeBuilderWithInferredContext<(ins "mlir::Type":$elementType), [{ + let builders = [TypeBuilderWithInferredContext< + (ins "mlir::Type":$elementType), [{ return Base::get(elementType.getContext(), elementType); }]>, ]; @@ -622,20 +622,27 @@ def AnyCompositeLike : TypeConstraint; def AnyReferenceType : TypeConstraint, "any reference type">; + fir_VolatileReferenceType.predicate]>, + "any reference type">; // Reference types -def AnyReferenceLike : TypeConstraint, "any reference">; +def AnyReferenceLike + : TypeConstraint< + Or<[fir_ReferenceType.predicate, fir_VolatileReferenceType.predicate, + fir_HeapType.predicate, fir_PointerType.predicate, + fir_LLVMPointerType.predicate]>, + "any reference">; def FuncType : TypeConstraint; def AnyCodeOrDataRefLike : TypeConstraint, "any code or data reference">; -def RefOrLLVMPtr : TypeConstraint, "fir.ref or fir.llvm_ptr">; +def RefOrLLVMPtr + : TypeConstraint< + Or<[fir_ReferenceType.predicate, fir_VolatileReferenceType.predicate, + fir_LLVMPointerType.predicate]>, + "fir.ref or fir.llvm_ptr">; def AnyBoxLike : TypeConstraint, "any reference or box like">; -def AnyRefOrBox : TypeConstraint, "any reference or box">; +def AnyRefOrBox + : TypeConstraint< + Or<[fir_ReferenceType.predicate, fir_VolatileReferenceType.predicate, + fir_HeapType.predicate, fir_PointerType.predicate, + IsBaseBoxTypePred]>, + "any reference or box">; def AnyRefOrBoxType : Type; def AnyShapeLike : TypeConstraint; // The legal types of global symbols -def AnyAddressableLike : TypeConstraint, "any addressable">; +def AnyAddressableLike + : TypeConstraint< + Or<[fir_ReferenceType.predicate, fir_VolatileReferenceType.predicate, + FunctionType.predicate]>, + "any addressable">; def ArrayOrBoxOrRecord : TypeConstraint, diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp index 632eb4f2ea612..2cc7d9359f94d 100644 --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -441,8 +441,9 @@ llvm::LogicalResult fir::ArrayCoorOp::verify() { if (sliceTy.getRank() != arrDim) return emitOpError("rank of dimension in slice mismatched"); } - if (!validTypeParams(getMemref().getType(), getTypeparams())) + if (!validTypeParams(getMemref().getType(), getTypeparams())) { return emitOpError("invalid type parameters"); + } return mlir::success(); } @@ -823,18 +824,15 @@ void fir::ArrayCoorOp::getCanonicalizationPatterns( //===----------------------------------------------------------------------===// static mlir::Type adjustedElementType(mlir::Type t) { - mlir::Type eleTy; - if (auto ty = mlir::dyn_cast(t)) { - eleTy = ty.getEleTy(); - } else if (auto volType = mlir::dyn_cast(t)) { - eleTy = volType.getEleTy(); + if (fir::isa_ref_type(t)) { + mlir::Type eleTy = fir::dyn_cast_ptrEleTy(t); + if (fir::isa_char(eleTy)) + return eleTy; + if (fir::isa_derived(eleTy)) + return eleTy; + if (mlir::isa(eleTy)) + return eleTy; } - if (fir::isa_char(eleTy)) - return eleTy; - if (fir::isa_derived(eleTy)) - return eleTy; - if (mlir::isa(eleTy)) - return eleTy; return t; } diff --git a/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp index 1dc6e659f6136..6567afe0188d0 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp @@ -410,9 +410,6 @@ class DesignateOpConversion mlir::Value shape, const llvm::SmallVector &firBaseTypeParameters) { assert(!designate.getIndices().empty()); - if (auto decl = mlir::dyn_cast(base.getDefiningOp())) { - base = decl.getResult(0); - } llvm::SmallVector firstElementIndices; auto indices = designate.getIndices(); int i = 0; @@ -428,6 +425,9 @@ class DesignateOpConversion mlir::isa(designateResultType); mlir::Type refTy = fir::ReferenceType::get(baseEleTy); mlir::Type volTy = fir::VolatileReferenceType::get(baseEleTy); + llvm::dbgs() << "baseEleTy: " << baseEleTy << "\n"; + llvm::dbgs() << "refTy: " << refTy << "\n"; + llvm::dbgs() << "volTy: " << volTy << "\n"; base = builder.create( loc, isVolatile ? volTy : refTy, base, shape, /*slice=*/mlir::Value{}, firstElementIndices, firBaseTypeParameters); From f07e68115390faea0d3cb38de5ceda9cfc8a89b9 Mon Sep 17 00:00:00 2001 From: Asher Mancinelli Date: Wed, 19 Mar 2025 11:58:43 -0700 Subject: [PATCH 5/5] checkpoint --- flang/include/flang/Optimizer/Dialect/FIRTypes.td | 1 + flang/lib/Optimizer/Dialect/FIRType.cpp | 4 ++++ flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp | 10 +++------- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/flang/include/flang/Optimizer/Dialect/FIRTypes.td b/flang/include/flang/Optimizer/Dialect/FIRTypes.td index da2de318ae8b3..cd40b6579cf4c 100644 --- a/flang/include/flang/Optimizer/Dialect/FIRTypes.td +++ b/flang/include/flang/Optimizer/Dialect/FIRTypes.td @@ -375,6 +375,7 @@ def fir_ReferenceType : FIR_Type<"Reference", "ref"> { let extraClassDeclaration = [{ mlir::Type getElementType() const { return getEleTy(); } + static mlir::Type get(mlir::Type t, bool isVolatile); }]; let genVerifyDecl = 1; diff --git a/flang/lib/Optimizer/Dialect/FIRType.cpp b/flang/lib/Optimizer/Dialect/FIRType.cpp index 7896c75aef60b..90614f1305c27 100644 --- a/flang/lib/Optimizer/Dialect/FIRType.cpp +++ b/flang/lib/Optimizer/Dialect/FIRType.cpp @@ -1076,6 +1076,10 @@ void fir::ReferenceType::print(mlir::AsmPrinter &printer) const { printer << "<" << getEleTy() << '>'; } +mlir::Type fir::ReferenceType::get(mlir::Type t, bool isVolatile) { + return isVolatile ? (mlir::Type)fir::VolatileReferenceType::get(t) : (mlir::Type)fir::ReferenceType::get(t); +} + llvm::LogicalResult fir::ReferenceType::verify( llvm::function_ref emitError, mlir::Type eleTy) { diff --git a/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp index 6567afe0188d0..98b64634f7d2f 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp @@ -423,14 +423,10 @@ class DesignateOpConversion auto designateResultType = designate.getResult().getType(); auto isVolatile = mlir::isa(designateResultType); - mlir::Type refTy = fir::ReferenceType::get(baseEleTy); - mlir::Type volTy = fir::VolatileReferenceType::get(baseEleTy); - llvm::dbgs() << "baseEleTy: " << baseEleTy << "\n"; - llvm::dbgs() << "refTy: " << refTy << "\n"; - llvm::dbgs() << "volTy: " << volTy << "\n"; + mlir::Type refTy = fir::ReferenceType::get(baseEleTy, isVolatile); base = builder.create( - loc, isVolatile ? volTy : refTy, base, shape, - /*slice=*/mlir::Value{}, firstElementIndices, firBaseTypeParameters); + loc, refTy, base, shape, /*slice=*/mlir::Value{}, + firstElementIndices, firBaseTypeParameters); return base; }