Skip to content

Commit b8b72cd

Browse files
committed
[flang] fix regression with optional after PR125059
1 parent 41910f7 commit b8b72cd

File tree

4 files changed

+160
-14
lines changed

4 files changed

+160
-14
lines changed

flang/include/flang/Optimizer/Builder/HLFIRTools.h

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -150,10 +150,7 @@ class Entity : public mlir::Value {
150150
return base.getDefiningOp<fir::FortranVariableOpInterface>();
151151
}
152152

153-
bool isOptional() const {
154-
auto varIface = getIfVariableInterface();
155-
return varIface ? varIface.isOptional() : false;
156-
}
153+
bool mayBeOptional() const;
157154

158155
bool isParameter() const {
159156
auto varIface = getIfVariableInterface();
@@ -210,7 +207,8 @@ class EntityWithAttributes : public Entity {
210207
using CleanupFunction = std::function<void()>;
211208
std::pair<fir::ExtendedValue, std::optional<CleanupFunction>>
212209
translateToExtendedValue(mlir::Location loc, fir::FirOpBuilder &builder,
213-
Entity entity, bool contiguousHint = false);
210+
Entity entity, bool contiguousHint = false,
211+
bool keepScalarOptionalBoxed = false);
214212

215213
/// Function to translate FortranVariableOpInterface to fir::ExtendedValue.
216214
/// It may generates IR to unbox fir.boxchar, but has otherwise no side effects

flang/lib/Optimizer/Builder/HLFIRTools.cpp

Lines changed: 102 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,24 @@ bool hlfir::Entity::mayHaveNonDefaultLowerBounds() const {
221221
return true;
222222
}
223223

224+
mlir::Operation* traverseConverts(mlir::Operation *op) {
225+
while (auto convert = llvm::dyn_cast_or_null<fir::ConvertOp>(op))
226+
op = convert.getValue().getDefiningOp();
227+
return op;
228+
}
229+
230+
bool hlfir::Entity::mayBeOptional() const {
231+
if (auto varIface = getIfVariableInterface())
232+
return varIface.isOptional();
233+
if (!isVariable())
234+
return false;
235+
// TODO: introduce a fir type to better identify optionals.
236+
if (mlir::Operation* op = traverseConverts(getDefiningOp()))
237+
return !llvm::isa<fir::AllocaOp, fir::AllocMemOp, fir::ReboxOp,
238+
fir::EmboxOp>(op);
239+
return true;
240+
}
241+
224242
fir::FortranVariableOpInterface
225243
hlfir::genDeclare(mlir::Location loc, fir::FirOpBuilder &builder,
226244
const fir::ExtendedValue &exv, llvm::StringRef name,
@@ -963,9 +981,68 @@ llvm::SmallVector<mlir::Value> hlfir::genLoopNestWithReductions(
963981
return outerLoop->getResults();
964982
}
965983

984+
template <typename Lambda>
985+
static fir::ExtendedValue
986+
conditionnalyEvaluate(mlir::Location loc, fir::FirOpBuilder &builder,
987+
mlir::Value condition, const Lambda &genIfTrue) {
988+
mlir::OpBuilder::InsertPoint insertPt = builder.saveInsertionPoint();
989+
990+
// Evaluate in some region that will be moved into the actual ifOp (the actual
991+
// ifOp can only be created when the result types are known).
992+
auto badIfOp = builder.create<fir::IfOp>(loc, condition.getType(), condition,
993+
/*withElseRegion=*/false);
994+
mlir::Block *preparationBlock = &badIfOp.getThenRegion().front();
995+
builder.setInsertionPointToStart(preparationBlock);
996+
fir::ExtendedValue result = genIfTrue();
997+
fir::ResultOp resultOp = result.match(
998+
[&](const fir::CharBoxValue &box) -> fir::ResultOp {
999+
return builder.create<fir::ResultOp>(
1000+
loc, mlir::ValueRange{box.getAddr(), box.getLen()});
1001+
},
1002+
[&](const mlir::Value &addr) -> fir::ResultOp {
1003+
return builder.create<fir::ResultOp>(loc, addr);
1004+
},
1005+
[&](const auto &) -> fir::ResultOp {
1006+
TODO(loc, "unboxing non scalar optional fir.box");
1007+
});
1008+
builder.restoreInsertionPoint(insertPt);
1009+
1010+
// Create actual fir.if operation.
1011+
auto ifOp =
1012+
builder.create<fir::IfOp>(loc, resultOp->getOperandTypes(), condition,
1013+
/*withElseRegion=*/true);
1014+
// Move evaluation into Then block,
1015+
preparationBlock->moveBefore(&ifOp.getThenRegion().back());
1016+
ifOp.getThenRegion().back().erase();
1017+
// Create absent result in the Else block.
1018+
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
1019+
llvm::SmallVector<mlir::Value> absentValues;
1020+
for (mlir::Type resTy : ifOp->getResultTypes()) {
1021+
if (fir::isa_ref_type(resTy) || fir::isa_box_type(resTy))
1022+
absentValues.emplace_back(builder.create<fir::AbsentOp>(loc, resTy));
1023+
else
1024+
absentValues.emplace_back(builder.create<fir::ZeroOp>(loc, resTy));
1025+
}
1026+
builder.create<fir::ResultOp>(loc, absentValues);
1027+
badIfOp->erase();
1028+
1029+
// Build fir::ExtendedValue from the result values.
1030+
builder.setInsertionPointAfter(ifOp);
1031+
return result.match(
1032+
[&](const fir::CharBoxValue &box) -> fir::ExtendedValue {
1033+
return fir::CharBoxValue{ifOp.getResult(0), ifOp.getResult(1)};
1034+
},
1035+
[&](const mlir::Value &) -> fir::ExtendedValue {
1036+
return ifOp.getResult(0);
1037+
},
1038+
[&](const auto &) -> fir::ExtendedValue {
1039+
TODO(loc, "unboxing non scalar optional fir.box");
1040+
});
1041+
}
1042+
9661043
static fir::ExtendedValue translateVariableToExtendedValue(
9671044
mlir::Location loc, fir::FirOpBuilder &builder, hlfir::Entity variable,
968-
bool forceHlfirBase = false, bool contiguousHint = false) {
1045+
bool forceHlfirBase = false, bool contiguousHint = false, bool keepScalarOptionalBoxed = false) {
9691046
assert(variable.isVariable() && "must be a variable");
9701047
// When going towards FIR, use the original base value to avoid
9711048
// introducing descriptors at runtime when they are not required.
@@ -984,14 +1061,33 @@ static fir::ExtendedValue translateVariableToExtendedValue(
9841061
const bool contiguous = variable.isSimplyContiguous() || contiguousHint;
9851062
const bool isAssumedRank = variable.isAssumedRank();
9861063
if (!contiguous || variable.isPolymorphic() ||
987-
variable.isDerivedWithLengthParameters() || variable.isOptional() ||
1064+
variable.isDerivedWithLengthParameters() ||
9881065
isAssumedRank) {
9891066
llvm::SmallVector<mlir::Value> nonDefaultLbounds;
9901067
if (!isAssumedRank)
9911068
nonDefaultLbounds = getNonDefaultLowerBounds(loc, builder, variable);
9921069
return fir::BoxValue(base, nonDefaultLbounds,
9931070
getExplicitTypeParams(variable));
9941071
}
1072+
if (variable.mayBeOptional()) {
1073+
if (!keepScalarOptionalBoxed && variable.isScalar()) {
1074+
mlir::Value isPresent = builder.create<fir::IsPresentOp>(
1075+
loc, builder.getI1Type(), variable);
1076+
return conditionnalyEvaluate(
1077+
loc, builder, isPresent, [&]() -> fir::ExtendedValue {
1078+
mlir::Value base = genVariableRawAddress(loc, builder, variable);
1079+
if (variable.isCharacter()) {
1080+
mlir::Value len =
1081+
genCharacterVariableLength(loc, builder, variable);
1082+
return fir::CharBoxValue{base, len};
1083+
}
1084+
return base;
1085+
});
1086+
}
1087+
llvm::SmallVector<mlir::Value> nonDefaultLbounds = getNonDefaultLowerBounds(loc, builder, variable);
1088+
return fir::BoxValue(base, nonDefaultLbounds,
1089+
getExplicitTypeParams(variable));
1090+
}
9951091
// Otherwise, the variable can be represented in a fir::ExtendedValue
9961092
// without the overhead of a fir.box.
9971093
base = genVariableRawAddress(loc, builder, variable);
@@ -1035,10 +1131,12 @@ hlfir::translateToExtendedValue(mlir::Location loc, fir::FirOpBuilder &builder,
10351131

10361132
std::pair<fir::ExtendedValue, std::optional<hlfir::CleanupFunction>>
10371133
hlfir::translateToExtendedValue(mlir::Location loc, fir::FirOpBuilder &builder,
1038-
hlfir::Entity entity, bool contiguousHint) {
1134+
hlfir::Entity entity, bool contiguousHint,
1135+
bool keepScalarOptionalBoxed) {
10391136
if (entity.isVariable())
10401137
return {translateVariableToExtendedValue(loc, builder, entity, false,
1041-
contiguousHint),
1138+
contiguousHint,
1139+
keepScalarOptionalBoxed),
10421140
std::nullopt};
10431141

10441142
if (entity.isProcedure()) {

flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,14 @@ class HlfirIntrinsicConversion : public mlir::OpRewritePattern<OP> {
121121
// simplified since the fir.box lowered here are now guarenteed to
122122
// contain the local lower bounds thanks to the hlfir.declare (the extra
123123
// rebox can be removed).
124-
auto [exv, cleanup] =
125-
hlfir::translateToExtendedValue(loc, builder, entity);
124+
// When taking arguments as descriptors, the runtime expect absent
125+
// OPTIONAL to be a nullptr to a descriptor, lowering has already
126+
// prepared such descriptors // as needed, hence set
127+
// keepScalarOptionalBoxed to avoid building descriptors with a null
128+
// address for them.
129+
auto [exv, cleanup] = hlfir::translateToExtendedValue(
130+
loc, builder, entity, /*contiguous=*/false,
131+
/*keepScalarOptionalBoxed=*/true);
126132
if (cleanup)
127133
cleanupFns.push_back(*cleanup);
128134
ret.emplace_back(exv);

flang/test/HLFIR/assign-codegen.fir

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -429,11 +429,55 @@ func.func @test_upoly_expr_assignment(%arg0: !fir.class<!fir.array<?xnone>> {fir
429429
// CHECK: }
430430

431431
func.func @test_scalar_box(%arg0: f32, %arg1: !fir.box<!fir.ptr<f32>>) {
432-
hlfir.assign %arg0 to %arg1 : f32, !fir.box<!fir.ptr<f32>>
432+
%x = fir.declare %arg1 {uniq_name = "x"} : (!fir.box<!fir.ptr<f32>>) -> !fir.box<!fir.ptr<f32>>
433+
hlfir.assign %arg0 to %x : f32, !fir.box<!fir.ptr<f32>>
433434
return
434435
}
435436
// CHECK-LABEL: func.func @test_scalar_box(
436437
// CHECK-SAME: %[[VAL_0:.*]]: f32,
437438
// CHECK-SAME: %[[VAL_1:.*]]: !fir.box<!fir.ptr<f32>>) {
438-
// CHECK: %[[VAL_2:.*]] = fir.box_addr %[[VAL_1]] : (!fir.box<!fir.ptr<f32>>) -> !fir.ptr<f32>
439-
// CHECK: fir.store %[[VAL_0]] to %[[VAL_2]] : !fir.ptr<f32>
439+
// CHECK: %[[VAL_2:.*]] = fir.declare %[[VAL_1]] {uniq_name = "x"} : (!fir.box<!fir.ptr<f32>>) -> !fir.box<!fir.ptr<f32>>
440+
// CHECK: %[[VAL_3:.*]] = fir.box_addr %[[VAL_2]] : (!fir.box<!fir.ptr<f32>>) -> !fir.ptr<f32>
441+
// CHECK: fir.store %[[VAL_0]] to %[[VAL_3]] : !fir.ptr<f32>
442+
443+
func.func @test_scalar_opt_box(%arg0: f32, %arg1: !fir.box<!fir.ptr<f32>>) {
444+
%x = fir.declare %arg1 {fortran_attrs = #fir.var_attrs<optional>, uniq_name = "x"} : (!fir.box<!fir.ptr<f32>>) -> !fir.box<!fir.ptr<f32>>
445+
hlfir.assign %arg0 to %x : f32, !fir.box<!fir.ptr<f32>>
446+
return
447+
}
448+
// CHECK-LABEL: func.func @test_scalar_opt_box(
449+
// CHECK-SAME: %[[VAL_0:.*]]: f32,
450+
// CHECK-SAME: %[[VAL_1:.*]]: !fir.box<!fir.ptr<f32>>) {
451+
// CHECK: %[[VAL_2:.*]] = fir.declare %[[VAL_1]] {fortran_attrs = #fir.var_attrs<optional>, uniq_name = "x"} : (!fir.box<!fir.ptr<f32>>) -> !fir.box<!fir.ptr<f32>>
452+
// CHECK: %[[VAL_3:.*]] = fir.is_present %[[VAL_2]] : (!fir.box<!fir.ptr<f32>>) -> i1
453+
// CHECK: %[[VAL_4:.*]] = fir.if %[[VAL_3]] -> (!fir.ptr<f32>) {
454+
// CHECK: %[[VAL_5:.*]] = fir.box_addr %[[VAL_2]] : (!fir.box<!fir.ptr<f32>>) -> !fir.ptr<f32>
455+
// CHECK: fir.result %[[VAL_5]] : !fir.ptr<f32>
456+
// CHECK: } else {
457+
// CHECK: %[[VAL_6:.*]] = fir.absent !fir.ptr<f32>
458+
// CHECK: fir.result %[[VAL_6]] : !fir.ptr<f32>
459+
// CHECK: }
460+
// CHECK: fir.store %[[VAL_0]] to %[[VAL_4]] : !fir.ptr<f32>
461+
462+
func.func @test_scalar_opt_char_box(%arg0: !fir.ref<!fir.char<1,10>>, %arg1: !fir.box<!fir.char<1,?>>) {
463+
%x = fir.declare %arg1 {fortran_attrs = #fir.var_attrs<optional>, uniq_name = "x"} : (!fir.box<!fir.char<1,?>>) -> !fir.box<!fir.char<1,?>>
464+
hlfir.assign %arg0 to %x : !fir.ref<!fir.char<1,10>>, !fir.box<!fir.char<1,?>>
465+
return
466+
}
467+
// CHECK-LABEL: func.func @test_scalar_opt_char_box(
468+
// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref<!fir.char<1,10>>,
469+
// CHECK-SAME: %[[VAL_1:.*]]: !fir.box<!fir.char<1,?>>) {
470+
// CHECK: %[[VAL_2:.*]] = fir.declare %[[VAL_1]] {fortran_attrs = #fir.var_attrs<optional>, uniq_name = "x"} : (!fir.box<!fir.char<1,?>>) -> !fir.box<!fir.char<1,?>>
471+
// CHECK: %[[VAL_3:.*]] = arith.constant 10 : index
472+
// CHECK: %[[VAL_4:.*]] = fir.is_present %[[VAL_2]] : (!fir.box<!fir.char<1,?>>) -> i1
473+
// CHECK: %[[VAL_5:.*]]:2 = fir.if %[[VAL_4]] -> (!fir.ref<!fir.char<1,?>>, index) {
474+
// CHECK: %[[VAL_6:.*]] = fir.box_addr %[[VAL_2]] : (!fir.box<!fir.char<1,?>>) -> !fir.ref<!fir.char<1,?>>
475+
// CHECK: %[[VAL_7:.*]] = fir.box_elesize %[[VAL_2]] : (!fir.box<!fir.char<1,?>>) -> index
476+
// CHECK: fir.result %[[VAL_6]], %[[VAL_7]] : !fir.ref<!fir.char<1,?>>, index
477+
// CHECK: } else {
478+
// CHECK: %[[VAL_8:.*]] = fir.absent !fir.ref<!fir.char<1,?>>
479+
// CHECK: %[[VAL_9:.*]] = fir.zero_bits index
480+
// CHECK: fir.result %[[VAL_8]], %[[VAL_9]] : !fir.ref<!fir.char<1,?>>, index
481+
// CHECK: }
482+
// ...
483+
// CHECK: fir.call @llvm.memmove.p0.p0.i64(

0 commit comments

Comments
 (0)