Skip to content

Commit e27bb5b

Browse files
checkpoint
1 parent 82077c4 commit e27bb5b

File tree

11 files changed

+105
-30
lines changed

11 files changed

+105
-30
lines changed

flang/include/flang/Optimizer/Builder/FIRBuilder.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
150150
mlir::Block *getAllocaBlock();
151151

152152
/// Safely create a reference type to the type `eleTy`.
153-
mlir::Type getRefType(mlir::Type eleTy);
153+
mlir::Type getRefType(mlir::Type eleTy, bool isVolatile = false);
154154

155155
/// Create a sequence of `eleTy` with `rank` dimensions of unknown size.
156156
mlir::Type getVarLenSeqTy(mlir::Type eleTy, unsigned rank = 1);

flang/include/flang/Optimizer/Dialect/FIRType.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,12 @@ inline bool isa_ref_type(mlir::Type t) {
111111
fir::LLVMPointerType>(t);
112112
}
113113

114+
inline bool isa_volatile_ref_type(mlir::Type t) {
115+
if (auto refTy = mlir::dyn_cast_or_null<fir::ReferenceType>(t))
116+
return refTy.isVolatile();
117+
return false;
118+
}
119+
114120
/// Is `t` a boxed type?
115121
inline bool isa_box_type(mlir::Type t) {
116122
return mlir::isa<fir::BaseBoxType, fir::BoxCharType, fir::BoxProcType>(t);

flang/include/flang/Optimizer/Dialect/FIRTypes.td

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -366,26 +366,23 @@ def fir_ReferenceType : FIR_Type<"Reference", "ref"> {
366366

367367
let parameters = (ins
368368
"mlir::Type":$eleTy,
369-
"mlir::UnitAttr":$isVol);
369+
DefaultValuedParameter<"bool", "false">:$isVol,
370+
DefaultValuedParameter<"bool", "false">:$isAsync);
370371

371372
let skipDefaultBuilders = 1;
372373

373374
let builders = [
374-
TypeBuilderWithInferredContext<(ins "mlir::Type":$elementType), [{
375-
return Base::get(elementType.getContext(), elementType);
375+
TypeBuilderWithInferredContext<(ins "mlir::Type":$elementType, CArg<"bool", "false">:$isVol, CArg<"bool", "false">:$isAsync), [{
376+
return Base::get(elementType.getContext(), elementType, isVol, isAsync);
376377
}]>,
377-
TypeBuilderWithInferredContext<(ins "mlir::Type":$elementType, "bool":$isVol)>,
378-
// [{
379-
// if (isVol)
380-
// return Base::get(elementType.getContext(), elementType, mlir::UnitAttr::get(elementType.getContext()));
381-
// else
382-
// return Base::get(elementType.getContext(), elementType);
383-
//}]>,
384378
];
385379

386380
let extraClassDeclaration = [{
387381
mlir::Type getElementType() const { return getEleTy(); }
388382
bool isVolatile() const { return (bool)getIsVol(); }
383+
bool isAsync() const { return (bool)getIsAsync(); }
384+
static llvm::StringRef getVolatileKeyword() { return "volatile"; }
385+
static llvm::StringRef getAsyncKeyword() { return "async"; }
389386
}];
390387

391388
let genVerifyDecl = 1;

flang/lib/Lower/CallInterface.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1112,7 +1112,7 @@ class Fortran::lower::CallInterfaceImpl {
11121112
if (obj.attrs.test(Attrs::Value))
11131113
isValueAttr = true; // TODO: do we want an mlir::Attribute as well?
11141114
if (obj.attrs.test(Attrs::Volatile)) {
1115-
TODO(loc, "VOLATILE in procedure interface");
1115+
// TODO(loc, "VOLATILE in procedure interface");
11161116
addMLIRAttr(fir::getVolatileAttrName());
11171117
}
11181118
// obj.attrs.test(Attrs::Asynchronous) does not impact the way the argument

flang/lib/Lower/ConvertExprToHLFIR.cpp

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -223,8 +223,37 @@ class HlfirDesignatorBuilder {
223223
designatorNode, getConverter().getFoldingContext(),
224224
/*namedConstantSectionsAreAlwaysContiguous=*/false))
225225
return fir::BoxType::get(resultValueType);
226+
227+
// TODO: handle async references
228+
bool isVolatile = false, isAsync = false;
229+
230+
// Check if the base type is volatile
231+
if (partInfo.base.has_value()) {
232+
mlir::Type baseType = partInfo.base.value().getType();
233+
isVolatile = fir::isa_volatile_ref_type(baseType);
234+
}
235+
236+
auto isVolatileSymbol = [&](const Fortran::semantics::Symbol &symbol) {
237+
return symbol.GetUltimate().attrs().test(Fortran::semantics::Attr::VOLATILE);
238+
};
239+
240+
// Check if this should be a volatile reference
241+
if constexpr (std::is_same_v<std::decay_t<T>,
242+
Fortran::evaluate::SymbolRef>) {
243+
if (isVolatileSymbol(designatorNode.get()))
244+
isVolatile = true;
245+
} else if constexpr (std::is_same_v<std::decay_t<T>,
246+
Fortran::evaluate::Component>) {
247+
if (isVolatileSymbol(designatorNode.GetLastSymbol()))
248+
isVolatile = true;
249+
}
250+
251+
// If it's a reference to a ref, account for it
252+
if (auto refTy = mlir::dyn_cast<fir::ReferenceType>(resultValueType))
253+
resultValueType = refTy.getEleTy();
254+
226255
// Other designators can be handled as raw addresses.
227-
return fir::ReferenceType::get(resultValueType);
256+
return fir::ReferenceType::get(resultValueType, isVolatile, isAsync);
228257
}
229258

230259
template <typename T>
@@ -269,6 +298,7 @@ class HlfirDesignatorBuilder {
269298
partInfo.componentName, partInfo.componentShape, partInfo.subscripts,
270299
partInfo.substring, partInfo.complexPart, partInfo.resultShape,
271300
partInfo.typeParams, attributes);
301+
llvm::dbgs() << __FILE__ << ":" << __LINE__ << "\n" << designate << "\n" << designatorType << "\n";
272302
if (auto elementalAddrOp = getVectorSubscriptElementAddrOp())
273303
builder.setInsertionPoint(*elementalAddrOp);
274304
return mlir::cast<fir::FortranVariableOpInterface>(
@@ -414,10 +444,13 @@ class HlfirDesignatorBuilder {
414444
.Case<fir::SequenceType>([&](fir::SequenceType seqTy) -> mlir::Type {
415445
return fir::SequenceType::get(seqTy.getShape(), newEleTy);
416446
})
417-
.Case<fir::PointerType, fir::HeapType, fir::ReferenceType, fir::BoxType,
418-
fir::ClassType>([&](auto t) -> mlir::Type {
419-
using FIRT = decltype(t);
420-
return FIRT::get(changeElementType(t.getEleTy(), newEleTy));
447+
.Case<fir::PointerType, fir::HeapType, fir::BoxType, fir::ClassType>(
448+
[&](auto t) -> mlir::Type {
449+
using FIRT = decltype(t);
450+
return FIRT::get(changeElementType(t.getEleTy(), newEleTy));
451+
})
452+
.Case<fir::ReferenceType>([&](fir::ReferenceType refTy) -> mlir::Type {
453+
return fir::ReferenceType::get(changeElementType(refTy.getEleTy(), newEleTy), refTy.isVolatile());
421454
})
422455
.Default([newEleTy](mlir::Type t) -> mlir::Type { return newEleTy; });
423456
}
@@ -1796,6 +1829,7 @@ class HlfirBuilder {
17961829
/*complexPart=*/std::nullopt,
17971830
/*shape=*/mlir::Value{}, /*typeParams=*/mlir::ValueRange{},
17981831
fir::FortranVariableFlagsAttr{});
1832+
llvm::dbgs() << __LINE__ << " " << newParent << "\n";
17991833
currentParent = hlfir::EntityWithAttributes{newParent};
18001834
}
18011835
valuesAndParents.emplace_back(
@@ -1808,6 +1842,7 @@ class HlfirBuilder {
18081842
auto &expr = std::get<const Fortran::lower::SomeExpr &>(iter);
18091843
auto &baseOp = std::get<hlfir::EntityWithAttributes>(iter);
18101844
std::string name = converter.getRecordTypeFieldName(sym);
1845+
const bool isVolatile = fir::isa_volatile_ref_type(baseOp.getType());
18111846

18121847
// Generate DesignateOp for the component.
18131848
// The designator's result type is just a reference to the component type,
@@ -1818,7 +1853,6 @@ class HlfirBuilder {
18181853
assert(compType && "failed to retrieve component type");
18191854
mlir::Value compShape =
18201855
designatorBuilder.genComponentShape(sym, compType);
1821-
mlir::Type designatorType = builder.getRefType(compType);
18221856

18231857
mlir::Type fieldElemType = hlfir::getFortranElementType(compType);
18241858
llvm::SmallVector<mlir::Value, 1> typeParams;
@@ -1839,6 +1873,7 @@ class HlfirBuilder {
18391873
// Convert component symbol attributes to variable attributes.
18401874
fir::FortranVariableFlagsAttr attrs =
18411875
Fortran::lower::translateSymbolAttributes(builder.getContext(), sym);
1876+
mlir::Type designatorType = builder.getRefType(compType, isVolatile);
18421877

18431878
// Get the component designator.
18441879
auto lhs = builder.create<hlfir::DesignateOp>(
@@ -1847,6 +1882,7 @@ class HlfirBuilder {
18471882
/*substring=*/mlir::ValueRange{},
18481883
/*complexPart=*/std::nullopt,
18491884
/*shape=*/compShape, typeParams, attrs);
1885+
llvm::dbgs() << __LINE__ << " " << lhs << "\n";
18501886

18511887
if (attrs && bitEnumContainsAny(attrs.getFlags(),
18521888
fir::FortranVariableFlagsEnum::pointer)) {

flang/lib/Optimizer/Builder/FIRBuilder.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,9 @@ fir::FirOpBuilder::getNamedGlobal(mlir::ModuleOp modOp,
104104
return modOp.lookupSymbol<fir::GlobalOp>(name);
105105
}
106106

107-
mlir::Type fir::FirOpBuilder::getRefType(mlir::Type eleTy) {
107+
mlir::Type fir::FirOpBuilder::getRefType(mlir::Type eleTy, bool isVolatile) {
108108
assert(!mlir::isa<fir::ReferenceType>(eleTy) && "cannot be a reference type");
109-
return fir::ReferenceType::get(eleTy);
109+
return fir::ReferenceType::get(eleTy, isVolatile);
110110
}
111111

112112
mlir::Type fir::FirOpBuilder::getVarLenSeqTy(mlir::Type eleTy, unsigned rank) {

flang/lib/Optimizer/Builder/HLFIRTools.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -809,7 +809,10 @@ mlir::Type hlfir::getVariableElementType(hlfir::Entity variable) {
809809
} else if (fir::isRecordWithTypeParameters(eleTy)) {
810810
return fir::BoxType::get(eleTy);
811811
}
812-
return fir::ReferenceType::get(eleTy);
812+
const bool isVolatile = fir::isa_volatile_ref_type(variable.getType());
813+
auto newty = fir::ReferenceType::get(eleTy, isVolatile);
814+
llvm::dbgs() << __FILE__ << ":" << __LINE__ << "\n" << variable << " " << variable.getType() << " newty:" << newty << " isvol:" << isVolatile << "\n";
815+
return newty;
813816
}
814817

815818
mlir::Type hlfir::getEntityElementType(hlfir::Entity entity) {

flang/lib/Optimizer/CodeGen/CodeGen.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3218,6 +3218,7 @@ struct LoadOpConversion : public fir::FIROpConversion<fir::LoadOp> {
32183218
mlir::ConversionPatternRewriter &rewriter) const override {
32193219

32203220
mlir::Type llvmLoadTy = convertObjectType(load.getType());
3221+
const bool isVolatile = fir::isa_volatile_ref_type(load.getMemref().getType());
32213222
if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(load.getType())) {
32223223
// fir.box is a special case because it is considered an ssa value in
32233224
// fir, but it is lowered as a pointer to a descriptor. So
@@ -3247,16 +3248,17 @@ struct LoadOpConversion : public fir::FIROpConversion<fir::LoadOp> {
32473248
mlir::Value boxSize =
32483249
computeBoxSize(loc, boxTypePair, inputBoxStorage, rewriter);
32493250
auto memcpy = rewriter.create<mlir::LLVM::MemcpyOp>(
3250-
loc, newBoxStorage, inputBoxStorage, boxSize, /*isVolatile=*/false);
3251+
loc, newBoxStorage, inputBoxStorage, boxSize, isVolatile);
32513252

32523253
if (std::optional<mlir::ArrayAttr> optionalTag = load.getTbaa())
32533254
memcpy.setTBAATags(*optionalTag);
32543255
else
32553256
attachTBAATag(memcpy, boxTy, boxTy, nullptr);
32563257
rewriter.replaceOp(load, newBoxStorage);
32573258
} else {
3259+
auto memref = adaptor.getOperands()[0];
32583260
auto loadOp = rewriter.create<mlir::LLVM::LoadOp>(
3259-
load.getLoc(), llvmLoadTy, adaptor.getOperands(), load->getAttrs());
3261+
load.getLoc(), llvmLoadTy, memref, /*alignment=*/0, isVolatile);
32603262
if (std::optional<mlir::ArrayAttr> optionalTag = load.getTbaa())
32613263
loadOp.setTBAATags(*optionalTag);
32623264
else

flang/lib/Optimizer/Dialect/FIRType.cpp

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1057,18 +1057,41 @@ unsigned fir::RecordType::getFieldIndex(llvm::StringRef ident) {
10571057
// ReferenceType
10581058
//===----------------------------------------------------------------------===//
10591059

1060-
// `ref` `<` type `>`
1060+
// `ref` `<` type (, volatile)? (, async)? `>`
10611061
mlir::Type fir::ReferenceType::parse(mlir::AsmParser &parser) {
1062-
return parseTypeSingleton<fir::ReferenceType>(parser);
1062+
if (parser.parseLess())
1063+
return {};
1064+
mlir::Type eleTy;
1065+
if (parser.parseType(eleTy))
1066+
return {};
1067+
bool isVolatile = false;
1068+
bool isAsync = false;
1069+
while (parser.parseOptionalComma()) {
1070+
if (parser.parseOptionalKeyword(getVolatileKeyword())) {
1071+
isVolatile = true;
1072+
} else if (parser.parseOptionalKeyword(getAsyncKeyword())) {
1073+
isAsync = true;
1074+
} else {
1075+
return {};
1076+
}
1077+
}
1078+
if (parser.parseGreater())
1079+
return {};
1080+
return ReferenceType::get(eleTy, isVolatile, isAsync);
10631081
}
10641082

10651083
void fir::ReferenceType::print(mlir::AsmPrinter &printer) const {
1066-
printer << "<" << getEleTy() << '>';
1084+
printer << "<" << getEleTy();
1085+
if (isVolatile())
1086+
printer << ", volatile";
1087+
if (isAsync())
1088+
printer << ", async";
1089+
printer << '>';
10671090
}
10681091

10691092
llvm::LogicalResult fir::ReferenceType::verify(
10701093
llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
1071-
mlir::Type eleTy, mlir::UnitAttr isVolatile) {
1094+
mlir::Type eleTy, bool isVolatile, bool isAsync) {
10721095
if (mlir::isa<ShapeType, ShapeShiftType, SliceType, FieldType, LenType,
10731096
ReferenceType, TypeDescType>(eleTy))
10741097
return emitError() << "cannot build a reference to type: " << eleTy << '\n';

flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,13 @@ void hlfir::DeclareOp::build(mlir::OpBuilder &builder,
214214
auto nameAttr = builder.getStringAttr(uniq_name);
215215
mlir::Type inputType = memref.getType();
216216
bool hasExplicitLbs = hasExplicitLowerBounds(shape);
217+
if (fortran_attrs && mlir::isa<fir::ReferenceType>(inputType) &&
218+
bitEnumContainsAny(fortran_attrs.getFlags(),
219+
fir::FortranVariableFlagsEnum::fortran_volatile)) {
220+
auto refType = mlir::cast<fir::ReferenceType>(inputType);
221+
inputType = fir::ReferenceType::get(refType.getEleTy(), true);
222+
memref = builder.create<fir::ConvertOp>(memref.getLoc(), inputType, memref);
223+
}
217224
mlir::Type hlfirVariableType =
218225
getHLFIRVariableType(inputType, hasExplicitLbs);
219226
build(builder, result, {hlfirVariableType, inputType}, memref, shape,

0 commit comments

Comments
 (0)