@@ -649,12 +649,17 @@ mlir::Type changeElementType(mlir::Type type, mlir::Type newElementType,
649649 .Case <fir::SequenceType>([&](fir::SequenceType seqTy) -> mlir::Type {
650650 return fir::SequenceType::get (seqTy.getShape (), newElementType);
651651 })
652- .Case <fir::PointerType, fir::HeapType, fir::ReferenceType,
653- fir::ClassType>([&](auto t) -> mlir::Type {
654- using FIRT = decltype (t);
655- return FIRT::get (
656- changeElementType (t.getEleTy (), newElementType, turnBoxIntoClass));
652+ .Case <fir::ReferenceType>([&](fir::ReferenceType refTy) -> mlir::Type {
653+ auto newEleTy = changeElementType (refTy.getEleTy (), newElementType,
654+ turnBoxIntoClass);
655+ return fir::ReferenceType::get (newEleTy, refTy.isVolatile ());
657656 })
657+ .Case <fir::PointerType, fir::HeapType, fir::ClassType>(
658+ [&](auto t) -> mlir::Type {
659+ using FIRT = decltype (t);
660+ return FIRT::get (changeElementType (t.getEleTy (), newElementType,
661+ turnBoxIntoClass));
662+ })
658663 .Case <fir::BoxType>([&](fir::BoxType t) -> mlir::Type {
659664 mlir::Type newInnerType =
660665 changeElementType (t.getEleTy (), newElementType, false );
@@ -1057,18 +1062,38 @@ unsigned fir::RecordType::getFieldIndex(llvm::StringRef ident) {
10571062// ReferenceType
10581063// ===----------------------------------------------------------------------===//
10591064
1060- // `ref` `<` type `>`
1065+ // `ref` `<` type (`, volatile` $volatile^)? `>`
10611066mlir::Type fir::ReferenceType::parse (mlir::AsmParser &parser) {
1062- return parseTypeSingleton<fir::ReferenceType>(parser);
1067+ if (parser.parseLess ())
1068+ return {};
1069+
1070+ mlir::Type eleTy;
1071+ if (parser.parseType (eleTy))
1072+ return {};
1073+
1074+ bool isVolatile = false ;
1075+ if (!parser.parseOptionalComma ()) {
1076+ if (parser.parseKeyword (getVolatileKeyword ())) {
1077+ return {};
1078+ }
1079+ isVolatile = true ;
1080+ }
1081+
1082+ if (parser.parseGreater ())
1083+ return {};
1084+ return get (eleTy, isVolatile);
10631085}
10641086
10651087void fir::ReferenceType::print (mlir::AsmPrinter &printer) const {
1066- printer << " <" << getEleTy () << ' >' ;
1088+ printer << " <" << getEleTy ();
1089+ if (isVolatile ())
1090+ printer << " , " << getVolatileKeyword ();
1091+ printer << ' >' ;
10671092}
10681093
10691094llvm::LogicalResult fir::ReferenceType::verify (
1070- llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
1071- mlir::Type eleTy ) {
1095+ llvm::function_ref<mlir::InFlightDiagnostic()> emitError, mlir::Type eleTy,
1096+ bool isVolatile ) {
10721097 if (mlir::isa<ShapeType, ShapeShiftType, SliceType, FieldType, LenType,
10731098 ReferenceType, TypeDescType>(eleTy))
10741099 return emitError () << " cannot build a reference to type: " << eleTy << ' \n ' ;
@@ -1319,11 +1344,15 @@ changeTypeShape(mlir::Type type,
13191344 return fir::SequenceType::get (*newShape, seqTy.getEleTy ());
13201345 return seqTy.getEleTy ();
13211346 })
1322- .Case <fir::PointerType, fir::HeapType, fir::ReferenceType, fir::BoxType,
1323- fir::ClassType>([&](auto t) -> mlir::Type {
1324- using FIRT = decltype (t);
1325- return FIRT::get (changeTypeShape (t.getEleTy (), newShape));
1347+ .Case <fir::ReferenceType>([&](fir::ReferenceType rt) -> mlir::Type {
1348+ return fir::ReferenceType::get (changeTypeShape (rt.getEleTy (), newShape),
1349+ rt.isVolatile ());
13261350 })
1351+ .Case <fir::PointerType, fir::HeapType, fir::BoxType, fir::ClassType>(
1352+ [&](auto t) -> mlir::Type {
1353+ using FIRT = decltype (t);
1354+ return FIRT::get (changeTypeShape (t.getEleTy (), newShape));
1355+ })
13271356 .Default ([&](mlir::Type t) -> mlir::Type {
13281357 assert ((fir::isa_trivial (t) || llvm::isa<fir::RecordType>(t) ||
13291358 llvm::isa<mlir::NoneType>(t) ||
0 commit comments