Skip to content

Commit 3beec2f

Browse files
authored
[flang] do not rely on existing fir.convert in TargetRewrite (#157413)
TargetRewrite is doing a shallow rewrite of function signatures. It is only rewriting function definitions (FuncOp), calls (CallOp) and AddressOfOp. It is not trying to visit each operations that may have an operand with a function type. It therefore needs function signature casts around the operations it is rewriting. Currently, these casts were not inserted after AddressOfOp rewrites because lowering tends to always insert function cast after generating AddressOfOp to the void type so the pass relied on implicitly updating this cast operand type to get the required cast. This is brittle because there is no guarantee such convert must be here and canonicalization and passes may remove them. Insert a cast after on the result of rewritten operations. If it is redundant, it will be canonicalized away later.
1 parent b45f1fb commit 3beec2f

File tree

2 files changed

+25
-5
lines changed

2 files changed

+25
-5
lines changed

flang/lib/Optimizer/CodeGen/TargetRewrite.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1336,7 +1336,15 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
13361336
private:
13371337
// Replace `op` and remove it.
13381338
void replaceOp(mlir::Operation *op, mlir::ValueRange newValues) {
1339-
op->replaceAllUsesWith(newValues);
1339+
llvm::SmallVector<mlir::Value> casts;
1340+
for (auto [oldValue, newValue] : llvm::zip(op->getResults(), newValues)) {
1341+
if (oldValue.getType() == newValue.getType())
1342+
casts.push_back(newValue);
1343+
else
1344+
casts.push_back(fir::ConvertOp::create(*rewriter, op->getLoc(),
1345+
oldValue.getType(), newValue));
1346+
}
1347+
op->replaceAllUsesWith(casts);
13401348
op->dropAllReferences();
13411349
op->erase();
13421350
}

flang/test/Fir/struct-return-x86-64.fir

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.data
1717
%1 = fir.convert %0 : (() -> !fits_in_reg) -> (() -> ())
1818
return %1 : () -> ()
1919
}
20+
func.func @test_addr_of_inreg_2() -> (() -> !fits_in_reg) {
21+
%0 = fir.address_of(@test_inreg) : () -> !fits_in_reg
22+
return %0 : () -> !fits_in_reg
23+
}
2024
func.func @test_dispatch_inreg(%arg0: !fir.ref<!fits_in_reg>, %arg1: !fir.class<!fir.type<somet>>) {
2125
%0 = fir.dispatch "bar"(%arg1 : !fir.class<!fir.type<somet>>) (%arg1 : !fir.class<!fir.type<somet>>) -> !fits_in_reg {pass_arg_pos = 0 : i32}
2226
fir.store %0 to %arg0 : !fir.ref<!fits_in_reg>
@@ -62,8 +66,15 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.data
6266

6367
// CHECK-LABEL: func.func @test_addr_of_inreg() -> (() -> ()) {
6468
// CHECK: %[[VAL_0:.*]] = fir.address_of(@test_inreg) : () -> tuple<i64, f32>
65-
// CHECK: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (() -> tuple<i64, f32>) -> (() -> ())
66-
// CHECK: return %[[VAL_1]] : () -> ()
69+
// CHECK: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (() -> tuple<i64, f32>) -> (() -> !fir.type<t1{i:f32,j:i32,k:f32}>)
70+
// CHECK: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<t1{i:f32,j:i32,k:f32}>) -> (() -> ())
71+
// CHECK: return %[[VAL_2]] : () -> ()
72+
// CHECK: }
73+
74+
// CHECK-LABEL: func.func @test_addr_of_inreg_2() -> (() -> !fir.type<t1{i:f32,j:i32,k:f32}>) {
75+
// CHECK: %[[VAL_0:.*]] = fir.address_of(@test_inreg) : () -> tuple<i64, f32>
76+
// CHECK: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (() -> tuple<i64, f32>) -> (() -> !fir.type<t1{i:f32,j:i32,k:f32}>)
77+
// CHECK: return %[[VAL_1]] : () -> !fir.type<t1{i:f32,j:i32,k:f32}>
6778
// CHECK: }
6879

6980
// CHECK-LABEL: func.func @test_dispatch_inreg(
@@ -95,8 +106,9 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.data
95106

96107
// CHECK-LABEL: func.func @test_addr_of_sret() -> (() -> ()) {
97108
// CHECK: %[[VAL_0:.*]] = fir.address_of(@test_sret) : (!fir.ref<!fir.type<t2{i:!fir.array<5xf32>}>>) -> ()
98-
// CHECK: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : ((!fir.ref<!fir.type<t2{i:!fir.array<5xf32>}>>) -> ()) -> (() -> ())
99-
// CHECK: return %[[VAL_1]] : () -> ()
109+
// CHECK: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : ((!fir.ref<!fir.type<t2{i:!fir.array<5xf32>}>>) -> ()) -> (() -> !fir.type<t2{i:!fir.array<5xf32>}>)
110+
// CHECK: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<t2{i:!fir.array<5xf32>}>) -> (() -> ())
111+
// CHECK: return %[[VAL_2]] : () -> ()
100112
// CHECK: }
101113

102114
// CHECK-LABEL: func.func @test_dispatch_sret(

0 commit comments

Comments
 (0)