diff --git a/flang/include/flang/Optimizer/Builder/FIRBuilder.h b/flang/include/flang/Optimizer/Builder/FIRBuilder.h index 1583cfb3f5b51..ddd4ef7114a63 100644 --- a/flang/include/flang/Optimizer/Builder/FIRBuilder.h +++ b/flang/include/flang/Optimizer/Builder/FIRBuilder.h @@ -150,7 +150,7 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener { mlir::Block *getAllocaBlock(); /// Safely create a reference type to the type `eleTy`. - mlir::Type getRefType(mlir::Type eleTy); + mlir::Type getRefType(mlir::Type eleTy, bool isVolatile = false); /// Create a sequence of `eleTy` with `rank` dimensions of unknown size. mlir::Type getVarLenSeqTy(mlir::Type eleTy, unsigned rank = 1); diff --git a/flang/include/flang/Optimizer/Dialect/FIRType.h b/flang/include/flang/Optimizer/Dialect/FIRType.h index 76e0aa352bcd9..0dbff258aea86 100644 --- a/flang/include/flang/Optimizer/Dialect/FIRType.h +++ b/flang/include/flang/Optimizer/Dialect/FIRType.h @@ -221,6 +221,10 @@ inline bool isa_char_string(mlir::Type t) { /// (since they may hold one), and are not considered to be unknown size. bool isa_unknown_size_box(mlir::Type t); +/// Returns true iff `t` is a type capable of representing volatility and has +/// the volatile attribute set. +bool isa_volatile_type(mlir::Type t); + /// Returns true iff `t` is a fir.char type and has an unknown length. inline bool characterWithDynamicLen(mlir::Type t) { if (auto charTy = mlir::dyn_cast(t)) @@ -474,6 +478,10 @@ inline mlir::Type updateTypeForUnlimitedPolymorphic(mlir::Type ty) { return ty; } +/// Re-create the given type with the given volatility, if this is a type +/// that can represent volatility. +mlir::Type updateTypeWithVolatility(mlir::Type type, bool isVolatile); + /// Replace the element type of \p type by \p newElementType, preserving /// all other layers of the type (fir.ref/ptr/heap/array/box/class). /// If \p turnBoxIntoClass and the input is a fir.box, it will be turned into diff --git a/flang/include/flang/Optimizer/Dialect/FIRTypes.td b/flang/include/flang/Optimizer/Dialect/FIRTypes.td index fd5bbbe44751f..84b3932ea75f6 100644 --- a/flang/include/flang/Optimizer/Dialect/FIRTypes.td +++ b/flang/include/flang/Optimizer/Dialect/FIRTypes.td @@ -77,24 +77,24 @@ def fir_BoxType : FIR_Type<"Box", "box", [], "BaseBoxType"> { to) whether the entity is an array, its size, or what type it has. }]; - let parameters = (ins "mlir::Type":$eleTy); + let parameters = (ins "mlir::Type":$eleTy, "bool":$isVolatile); let skipDefaultBuilders = 1; let builders = [ TypeBuilderWithInferredContext<(ins - "mlir::Type":$eleTy), [{ - return Base::get(eleTy.getContext(), eleTy); + "mlir::Type":$eleTy, CArg<"bool", "false">:$isVolatile), [{ + return Base::get(eleTy.getContext(), eleTy, isVolatile); }]>, ]; let extraClassDeclaration = [{ mlir::Type getElementType() const { return getEleTy(); } + bool isVolatile() const { return getIsVolatile(); } }]; let genVerifyDecl = 1; - - let assemblyFormat = "`<` $eleTy `>`"; + let hasCustomAssemblyFormat = 1; } def fir_CharacterType : FIR_Type<"Character", "char"> { @@ -146,16 +146,20 @@ def fir_ClassType : FIR_Type<"Class", "class", [], "BaseBoxType"> { is equivalent to a fir.box type with a dynamic type. }]; - let parameters = (ins "mlir::Type":$eleTy); + let parameters = (ins "mlir::Type":$eleTy, "bool":$isVolatile); let builders = [ - TypeBuilderWithInferredContext<(ins "mlir::Type":$eleTy), [{ - return $_get(eleTy.getContext(), eleTy); + TypeBuilderWithInferredContext<(ins "mlir::Type":$eleTy, CArg<"bool", "false">:$isVolatile), [{ + return $_get(eleTy.getContext(), eleTy, isVolatile); }]> ]; + let extraClassDeclaration = [{ + bool isVolatile() const { return getIsVolatile(); } + }]; + let genVerifyDecl = 1; - let assemblyFormat = "`<` $eleTy `>`"; + let hasCustomAssemblyFormat = 1; } def fir_FieldType : FIR_Type<"Field", "field"> { @@ -363,18 +367,19 @@ def fir_ReferenceType : FIR_Type<"Reference", "ref"> { The type of a reference to an entity in memory. }]; - let parameters = (ins "mlir::Type":$eleTy); + let parameters = (ins "mlir::Type":$eleTy, "bool":$isVolatile); let skipDefaultBuilders = 1; let builders = [ - TypeBuilderWithInferredContext<(ins "mlir::Type":$elementType), [{ - return Base::get(elementType.getContext(), elementType); + TypeBuilderWithInferredContext<(ins "mlir::Type":$elementType, CArg<"bool", "false">:$isVolatile), [{ + return Base::get(elementType.getContext(), elementType, isVolatile); }]>, ]; let extraClassDeclaration = [{ mlir::Type getElementType() const { return getEleTy(); } + bool isVolatile() const { return getIsVolatile(); } }]; let genVerifyDecl = 1; diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp index fdc155ef2ef18..7fc30ca125a87 100644 --- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp +++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp @@ -105,9 +105,9 @@ fir::FirOpBuilder::getNamedGlobal(mlir::ModuleOp modOp, return modOp.lookupSymbol(name); } -mlir::Type fir::FirOpBuilder::getRefType(mlir::Type eleTy) { +mlir::Type fir::FirOpBuilder::getRefType(mlir::Type eleTy, bool isVolatile) { assert(!mlir::isa(eleTy) && "cannot be a reference type"); - return fir::ReferenceType::get(eleTy); + return fir::ReferenceType::get(eleTy, isVolatile); } mlir::Type fir::FirOpBuilder::getVarLenSeqTy(mlir::Type eleTy, unsigned rank) { diff --git a/flang/lib/Optimizer/Dialect/FIRType.cpp b/flang/lib/Optimizer/Dialect/FIRType.cpp index f3f969ba401e5..1df0ea93b759f 100644 --- a/flang/lib/Optimizer/Dialect/FIRType.cpp +++ b/flang/lib/Optimizer/Dialect/FIRType.cpp @@ -32,6 +32,21 @@ using namespace fir; namespace { +static llvm::StringRef getVolatileKeyword() { return "volatile"; } + +static mlir::ParseResult parseOptionalCommaAndKeyword(mlir::AsmParser &parser, + mlir::StringRef keyword, + bool &parsedKeyword) { + if (!parser.parseOptionalComma()) { + if (parser.parseKeyword(keyword)) + return mlir::failure(); + parsedKeyword = true; + return mlir::success(); + } + parsedKeyword = false; + return mlir::success(); +} + template TYPE parseIntSingleton(mlir::AsmParser &parser) { int kind = 0; @@ -215,6 +230,19 @@ mlir::Type getDerivedType(mlir::Type ty) { .Default([](mlir::Type t) { return t; }); } +mlir::Type updateTypeWithVolatility(mlir::Type type, bool isVolatile) { + // If we already have the volatility we asked for, return the type unchanged. + if (fir::isa_volatile_type(type) == isVolatile) + return type; + return mlir::TypeSwitch(type) + .Case( + [&](auto ty) -> mlir::Type { + using TYPE = decltype(ty); + return TYPE::get(ty.getEleTy(), isVolatile); + }) + .Default([&](mlir::Type t) -> mlir::Type { return t; }); +} + mlir::Type dyn_cast_ptrEleTy(mlir::Type t) { return llvm::TypeSwitch(t) .Case(t) + .Case( + [](auto t) { return t.isVolatile(); }) + .Default([](mlir::Type) { return false; }); +} + //===----------------------------------------------------------------------===// // BoxProcType //===----------------------------------------------------------------------===// @@ -738,9 +773,31 @@ static bool cannotBePointerOrHeapElementType(mlir::Type eleTy) { // BoxType //===----------------------------------------------------------------------===// +// `box` `<` type (`, volatile` $volatile^)? `>` +mlir::Type fir::BoxType::parse(mlir::AsmParser &parser) { + mlir::Type eleTy; + auto location = parser.getCurrentLocation(); + auto *context = parser.getContext(); + bool isVolatile = false; + if (parser.parseLess() || parser.parseType(eleTy)) + return {}; + if (parseOptionalCommaAndKeyword(parser, getVolatileKeyword(), isVolatile)) + return {}; + if (parser.parseGreater()) + return {}; + return parser.getChecked(location, context, eleTy, isVolatile); +} + +void fir::BoxType::print(mlir::AsmPrinter &printer) const { + printer << "<" << getEleTy(); + if (isVolatile()) + printer << ", " << getVolatileKeyword(); + printer << '>'; +} + llvm::LogicalResult fir::BoxType::verify(llvm::function_ref emitError, - mlir::Type eleTy) { + mlir::Type eleTy, bool isVolatile) { if (mlir::isa(eleTy)) return emitError() << "invalid element type\n"; // TODO @@ -807,9 +864,32 @@ void fir::CharacterType::print(mlir::AsmPrinter &printer) const { // ClassType //===----------------------------------------------------------------------===// +// `class` `<` type (`, volatile` $volatile^)? `>` +mlir::Type fir::ClassType::parse(mlir::AsmParser &parser) { + mlir::Type eleTy; + auto location = parser.getCurrentLocation(); + auto *context = parser.getContext(); + bool isVolatile = false; + if (parser.parseLess() || parser.parseType(eleTy)) + return {}; + if (parseOptionalCommaAndKeyword(parser, getVolatileKeyword(), isVolatile)) + return {}; + if (parser.parseGreater()) + return {}; + return parser.getChecked(location, context, eleTy, + isVolatile); +} + +void fir::ClassType::print(mlir::AsmPrinter &printer) const { + printer << "<" << getEleTy(); + if (isVolatile()) + printer << ", " << getVolatileKeyword(); + printer << '>'; +} + llvm::LogicalResult fir::ClassType::verify(llvm::function_ref emitError, - mlir::Type eleTy) { + mlir::Type eleTy, bool isVolatile) { if (mlir::isa` +// `ref` `<` type (`, volatile` $volatile^)? `>` mlir::Type fir::ReferenceType::parse(mlir::AsmParser &parser) { - return parseTypeSingleton(parser); + auto location = parser.getCurrentLocation(); + auto *context = parser.getContext(); + mlir::Type eleTy; + bool isVolatile = false; + if (parser.parseLess() || parser.parseType(eleTy)) + return {}; + if (parseOptionalCommaAndKeyword(parser, getVolatileKeyword(), isVolatile)) + return {}; + if (parser.parseGreater()) + return {}; + return parser.getChecked(location, context, eleTy, + isVolatile); } void fir::ReferenceType::print(mlir::AsmPrinter &printer) const { - printer << "<" << getEleTy() << '>'; + printer << "<" << getEleTy(); + if (isVolatile()) + printer << ", " << getVolatileKeyword(); + printer << '>'; } llvm::LogicalResult fir::ReferenceType::verify( - llvm::function_ref emitError, - mlir::Type eleTy) { + llvm::function_ref emitError, mlir::Type eleTy, + bool isVolatile) { if (mlir::isa(eleTy)) return emitError() << "cannot build a reference to type: " << eleTy << '\n'; diff --git a/flang/test/Fir/invalid-types.fir b/flang/test/Fir/invalid-types.fir index f4505097086ad..a3dc9242c4eb3 100644 --- a/flang/test/Fir/invalid-types.fir +++ b/flang/test/Fir/invalid-types.fir @@ -6,8 +6,7 @@ func.func private @box3() -> !fir.boxproc<> // ----- -// expected-error@+2 {{expected non-function type}} -// expected-error@+1 {{failed to parse fir_BoxType parameter 'eleTy' which is to be a `mlir::Type`}} +// expected-error@+1 {{expected non-function type}} func.func private @box1() -> !fir.box<> // ----- @@ -105,6 +104,11 @@ func.func private @mem3() -> !fir.ref<> // ----- +// expected-error@+1 {{expected non-function type}} +func.func private @mem3() -> !fir.ref<, volatile> + +// ----- + // expected-error@+1 {{expected ':'}} func.func private @arr1() -> !fir.array<*> @@ -162,3 +166,24 @@ func.func private @upe() -> !fir.class> // expected-error@+1 {{invalid element type}} func.func private @upe() -> !fir.box> + +// ----- + +// expected-error@+1 {{invalid element type}} +func.func private @upe() -> !fir.box, volatile> + +// ----- + +// expected-error@+1 {{invalid element type}} +func.func private @upe() -> !fir.class> + +// ----- + +// expected-error@+1 {{invalid element type}} +func.func private @upe() -> !fir.class, volatile> + +// ----- + +// expected-error@+1 {{expected non-function type}} +func.func private @upe() -> !fir.class<, volatile> + diff --git a/flang/unittests/Optimizer/FIRTypesTest.cpp b/flang/unittests/Optimizer/FIRTypesTest.cpp index b3151b4aa7efb..28d5eb7ead25f 100644 --- a/flang/unittests/Optimizer/FIRTypesTest.cpp +++ b/flang/unittests/Optimizer/FIRTypesTest.cpp @@ -316,3 +316,39 @@ TEST_F(FIRTypesTest, getTypeAsString) { EXPECT_EQ("boxchar_c8xU", fir::getTypeAsString(fir::BoxCharType::get(&context, 1), *kindMap)); } + +TEST_F(FIRTypesTest, isVolatileType) { + mlir::Type i32 = mlir::IntegerType::get(&context, 32); + + mlir::Type i32NonVolatileRef = fir::ReferenceType::get(i32); + mlir::Type i32NonVolatileBox = fir::BoxType::get(i32); + mlir::Type i32NonVolatileClass = fir::ClassType::get(i32); + + // Ensure the default value is false + EXPECT_EQ(i32NonVolatileRef, fir::ReferenceType::get(i32, false)); + EXPECT_EQ(i32NonVolatileBox, fir::BoxType::get(i32, false)); + EXPECT_EQ(i32NonVolatileClass, fir::ClassType::get(i32, false)); + + EXPECT_FALSE(fir::isa_volatile_type(i32)); + EXPECT_FALSE(fir::isa_volatile_type(i32NonVolatileRef)); + EXPECT_FALSE(fir::isa_volatile_type(i32NonVolatileBox)); + EXPECT_FALSE(fir::isa_volatile_type(i32NonVolatileClass)); + + // Should return the same type if it's not capable of representing volatility. + EXPECT_EQ(i32, fir::updateTypeWithVolatility(i32, true)); + + mlir::Type i32VolatileRef = + fir::updateTypeWithVolatility(i32NonVolatileRef, true); + mlir::Type i32VolatileBox = + fir::updateTypeWithVolatility(i32NonVolatileBox, true); + mlir::Type i32VolatileClass = + fir::updateTypeWithVolatility(i32NonVolatileClass, true); + + EXPECT_TRUE(fir::isa_volatile_type(i32VolatileRef)); + EXPECT_TRUE(fir::isa_volatile_type(i32VolatileBox)); + EXPECT_TRUE(fir::isa_volatile_type(i32VolatileClass)); + + EXPECT_EQ(i32VolatileRef, fir::ReferenceType::get(i32, true)); + EXPECT_EQ(i32VolatileBox, fir::BoxType::get(i32, true)); + EXPECT_EQ(i32VolatileClass, fir::ClassType::get(i32, true)); +}