Skip to content

Commit 71586a6

Browse files
[mlir][bufferize] Make buffer-results-to-out-params support only functions that are neither public nor extern (#162441)
The callers of public or extern functions are unknown, so their function signatures cannot be changed.
1 parent 100db53 commit 71586a6

File tree

6 files changed

+50
-48
lines changed

6 files changed

+50
-48
lines changed

mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,9 @@ updateCalls(ModuleOp module, const AllocDynamicSizesMap &map,
217217
}
218218
if (!options.filterFn(&callee))
219219
return;
220+
if (callee.isExternal() || callee.isPublic())
221+
return;
222+
220223
SmallVector<Value, 6> replaceWithNewCallResults;
221224
SmallVector<Value, 6> replaceWithOutParams;
222225
for (OpResult result : op.getResults()) {
@@ -292,14 +295,14 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
292295
// function.
293296
AllocDynamicSizesMap map;
294297
for (auto func : module.getOps<func::FuncOp>()) {
298+
if (func.isExternal() || func.isPublic())
299+
continue;
295300
if (!options.filterFn(&func))
296301
continue;
297302
SmallVector<BlockArgument, 6> appendedEntryArgs;
298303
if (failed(
299304
updateFuncOp(func, appendedEntryArgs, options.addResultAttribute)))
300305
return failure();
301-
if (func.isExternal())
302-
continue;
303306
if (failed(updateReturnOps(func, appendedEntryArgs, map, options))) {
304307
return failure();
305308
}

mlir/test/Conversion/ConvertToEmitC/tosa.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
// RUN: mlir-opt --pass-pipeline=%{pipeline} %s | FileCheck %s
2121
// -----
2222

23-
// CHECK: emitc.func @main(%[[ARG0:.*]]: !emitc.array<2xf32>, %[[ARG1:.*]]: !emitc.array<2xf32>, %[[RES:.*]]: !emitc.array<2xf32>) {
23+
// CHECK: emitc.func private @main(%[[ARG0:.*]]: !emitc.array<2xf32>, %[[ARG1:.*]]: !emitc.array<2xf32>, %[[RES:.*]]: !emitc.array<2xf32>)
2424
// CHECK-DAG: %[[C0:.*]] = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
2525
// CHECK-DAG: %[[C1:.*]] = "emitc.constant"() <{value = 1 : index}> : () -> !emitc.size_t
2626
// CHECK-DAG: %[[C2:.*]] = "emitc.constant"() <{value = 2 : index}> : () -> !emitc.size_t
@@ -35,7 +35,7 @@
3535
// CHECK-NEXT: }
3636
// CHECK-NEXT: return
3737
// CHECK-NEXT: }
38-
func.func @main(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
38+
func.func private @main(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
3939
%0 = tosa.add %arg0, %arg1 : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
4040
return %0 : tensor<2xf32>
4141
}

mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-out-params.mlir

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
// Note: This bufferization is not very efficient yet, but it works.
99

10-
// CHECK-LABEL: func @callee(
10+
// CHECK-LABEL: func private @callee(
1111
// CHECK-SAME: %[[arg0:.*]]: memref<5xf32, strided<[?], offset: ?>>,
1212
// CHECK-SAME: %[[arg1:.*]]: memref<5xf32, strided<[?], offset: ?>>) {
1313
// This alloc is not needed, but it is inserted due to the out-of-place
@@ -21,21 +21,21 @@
2121
// CHECK: return
2222
// CHECK: }
2323

24-
// CHECK-NO-LAYOUT-LABEL: func @callee(
24+
// CHECK-NO-LAYOUT-LABEL: func private @callee(
2525
// CHECK-NO-LAYOUT-SAME: %[[arg0:.*]]: memref<5xf32>,
2626
// CHECK-NO-LAYOUT-SAME: %[[arg1:.*]]: memref<5xf32>) {
2727
// CHECK-NO-LAYOUT: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<5xf32>
2828
// CHECK-NO-LAYOUT: memref.copy %[[arg0]], %[[alloc]]
2929
// CHECK-NO-LAYOUT: memref.store {{.*}}, %[[alloc]]
3030
// CHECK-NO-LAYOUT: memref.copy %[[alloc]], %[[arg1]]
3131

32-
// CHECK-BASELINE-LABEL: func @callee(
32+
// CHECK-BASELINE-LABEL: func private @callee(
3333
// CHECK-BASELINE-SAME: %[[arg0:.*]]: memref<5xf32, strided<[?], offset: ?>>) -> memref<5xf32> {
3434
// CHECK-BASELINE: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<5xf32>
3535
// CHECK-BASELINE: memref.copy %[[arg0]], %[[alloc]]
3636
// CHECK-BASELINE: memref.store {{.*}}, %[[alloc]]
3737
// CHECK-BASELINE: return %[[alloc]]
38-
func.func @callee(%t: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) {
38+
func.func private @callee(%t: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) {
3939
%c0 = arith.constant 0 : index
4040
%cst = arith.constant 8.0 : f32
4141
// This must bufferize out-of-place.
@@ -68,15 +68,15 @@ func.func @main(%t: tensor<5xf32>) -> (f32, f32) {
6868

6969
// -----
7070

71-
// CHECK-LABEL: func @callee(
71+
// CHECK-LABEL: func private @callee(
7272
// CHECK-SAME: %{{.*}}: index,
7373
// CHECK-SAME: %[[r:.*]]: memref<2x5xf32, strided<[?, ?], offset: ?>>) {
7474
// CHECK: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<10x20xf32>
7575
// CHECK: %[[subview:.*]] = memref.subview %[[alloc]]{{.*}} : memref<10x20xf32> to memref<2x5xf32, strided<[20, 1], offset: ?>>
7676
// CHECK: %[[casted:.*]] = memref.cast %[[subview]]
7777
// CHECK: memref.copy %[[casted]], %[[r]]
7878

79-
// CHECK-NO-LAYOUT-LABEL: func @callee(
79+
// CHECK-NO-LAYOUT-LABEL: func private @callee(
8080
// CHECK-NO-LAYOUT-SAME: %{{.*}}: index,
8181
// CHECK-NO-LAYOUT-SAME: %[[r:.*]]: memref<2x5xf32>) {
8282
// CHECK-NO-LAYOUT: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<10x20xf32>
@@ -88,12 +88,12 @@ func.func @main(%t: tensor<5xf32>) -> (f32, f32) {
8888
// CHECK-NO-LAYOUT: memref.copy %[[subview]], %[[alloc2]]
8989
// CHECK-NO-LAYOUT: memref.copy %[[alloc2]], %[[r]]
9090

91-
// CHECK-BASELINE-LABEL: func @callee(
91+
// CHECK-BASELINE-LABEL: func private @callee(
9292
// CHECK-BASELINE-SAME: %{{.*}}: index) -> memref<2x5xf32, strided<[20, 1], offset: ?>> {
9393
// CHECK-BASELINE: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<10x20xf32>
9494
// CHECK-BASELINE: %[[subview:.*]] = memref.subview %[[alloc]]
9595
// CHECK-BASELINE: return %[[subview]]
96-
func.func @callee(%idx: index) -> tensor<2x5xf32> {
96+
func.func private @callee(%idx: index) -> tensor<2x5xf32> {
9797
%0 = bufferization.alloc_tensor() : tensor<10x20xf32>
9898
%1 = tensor.extract_slice %0[%idx, %idx][2, 5][1, 1] : tensor<10x20xf32> to tensor<2x5xf32>
9999
return %1 : tensor<2x5xf32>
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
// RUN: mlir-opt -p 'builtin.module(buffer-results-to-out-params{add-result-attr})' -split-input-file -verify-diagnostics %s | FileCheck %s
22

33
// CHECK-LABEL: @basic({{.*}}: memref<f32> {bufferize.result})
4-
func.func @basic() -> (memref<f32>) {
4+
func.func private @basic() -> (memref<f32>) {
55
%0 = "test.source"() : () -> (memref<f32>)
66
return %0 : memref<f32>
77
}
@@ -11,7 +11,7 @@ func.func @basic() -> (memref<f32>) {
1111
// CHECK-LABEL: multiple_results
1212
// CHECK-SAME: memref<1xf32> {bufferize.result}
1313
// CHECK-SAME: memref<2xf32> {bufferize.result}
14-
func.func @multiple_results() -> (memref<1xf32>, memref<2xf32>) {
14+
func.func private @multiple_results() -> (memref<1xf32>, memref<2xf32>) {
1515
%0, %1 = "test.source"() : () -> (memref<1xf32>, memref<2xf32>)
1616
return %0, %1 : memref<1xf32>, memref<2xf32>
1717
}
Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,36 @@
11
// RUN: mlir-opt -allow-unregistered-dialect -p 'builtin.module(buffer-results-to-out-params{hoist-static-allocs})' %s | FileCheck %s
22

3-
// CHECK-LABEL: func @basic(
3+
// CHECK-LABEL: func private @basic(
44
// CHECK-SAME: %[[ARG:.*]]: memref<8x64xf32>) {
55
// CHECK-NOT: memref.alloc()
66
// CHECK: "test.source"(%[[ARG]]) : (memref<8x64xf32>) -> ()
77
// CHECK: return
88
// CHECK: }
9-
func.func @basic() -> (memref<8x64xf32>) {
9+
func.func private @basic() -> (memref<8x64xf32>) {
1010
%b = memref.alloc() : memref<8x64xf32>
1111
"test.source"(%b) : (memref<8x64xf32>) -> ()
1212
return %b : memref<8x64xf32>
1313
}
1414

15-
// CHECK-LABEL: func @basic_no_change(
15+
// CHECK-LABEL: func private @basic_no_change(
1616
// CHECK-SAME: %[[ARG:.*]]: memref<f32>) {
1717
// CHECK: %[[RESULT:.*]] = "test.source"() : () -> memref<f32>
1818
// CHECK: memref.copy %[[RESULT]], %[[ARG]] : memref<f32> to memref<f32>
1919
// CHECK: return
2020
// CHECK: }
21-
func.func @basic_no_change() -> (memref<f32>) {
21+
func.func private @basic_no_change() -> (memref<f32>) {
2222
%0 = "test.source"() : () -> (memref<f32>)
2323
return %0 : memref<f32>
2424
}
2525

26-
// CHECK-LABEL: func @basic_dynamic(
26+
// CHECK-LABEL: func private @basic_dynamic(
2727
// CHECK-SAME: %[[D:.*]]: index, %[[ARG:.*]]: memref<?xf32>) {
2828
// CHECK: %[[RESULT:.*]] = memref.alloc(%[[D]]) : memref<?xf32>
2929
// CHECK: "test.source"(%[[RESULT]]) : (memref<?xf32>) -> ()
3030
// CHECK: memref.copy %[[RESULT]], %[[ARG]]
3131
// CHECK: return
3232
// CHECK: }
33-
func.func @basic_dynamic(%d: index) -> (memref<?xf32>) {
33+
func.func private @basic_dynamic(%d: index) -> (memref<?xf32>) {
3434
%b = memref.alloc(%d) : memref<?xf32>
3535
"test.source"(%b) : (memref<?xf32>) -> ()
3636
return %b : memref<?xf32>
@@ -39,13 +39,13 @@ func.func @basic_dynamic(%d: index) -> (memref<?xf32>) {
3939
// -----
4040

4141
// no change due to writing to func args
42-
// CHECK-LABEL: func @return_arg(
42+
// CHECK-LABEL: func private @return_arg(
4343
// CHECK-SAME: %[[ARG0:.*]]: memref<128x256xf32>, %[[ARG1:.*]]: memref<128x256xf32>, %[[ARG2:.*]]: memref<128x256xf32>) {
4444
// CHECK: "test.source"(%[[ARG0]], %[[ARG1]])
4545
// CHECK: memref.copy
4646
// CHECK: return
4747
// CHECK: }
48-
func.func @return_arg(%arg0: memref<128x256xf32>, %arg1: memref<128x256xf32>) -> memref<128x256xf32> {
48+
func.func private @return_arg(%arg0: memref<128x256xf32>, %arg1: memref<128x256xf32>) -> memref<128x256xf32> {
4949
"test.source"(%arg0, %arg1) : (memref<128x256xf32>, memref<128x256xf32>) -> ()
5050
return %arg0 : memref<128x256xf32>
5151
}
Lines changed: 25 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,67 +1,66 @@
11
// RUN: mlir-opt -buffer-results-to-out-params -split-input-file -verify-diagnostics %s | FileCheck %s
22

3-
// CHECK-LABEL: func @basic(
3+
// CHECK-LABEL: func private @basic(
44
// CHECK-SAME: %[[ARG:.*]]: memref<f32>) {
55
// CHECK: %[[RESULT:.*]] = "test.source"() : () -> memref<f32>
66
// CHECK: memref.copy %[[RESULT]], %[[ARG]] : memref<f32> to memref<f32>
77
// CHECK: return
88
// CHECK: }
9-
func.func @basic() -> (memref<f32>) {
9+
func.func private @basic() -> (memref<f32>) {
1010
%0 = "test.source"() : () -> (memref<f32>)
1111
return %0 : memref<f32>
1212
}
1313

14-
// CHECK-LABEL: func @presence_of_existing_arguments(
14+
// CHECK-LABEL: func private @presence_of_existing_arguments(
1515
// CHECK-SAME: %[[ARG0:.*]]: memref<1xf32>,
1616
// CHECK-SAME: %[[ARG1:.*]]: memref<2xf32>) {
1717
// CHECK: %[[RESULT:.*]] = "test.source"() : () -> memref<2xf32>
1818
// CHECK: memref.copy %[[RESULT]], %[[ARG1]] : memref<2xf32> to memref<2xf32>
1919
// CHECK: return
2020
// CHECK: }
21-
func.func @presence_of_existing_arguments(%arg0: memref<1xf32>) -> (memref<2xf32>) {
21+
func.func private @presence_of_existing_arguments(%arg0: memref<1xf32>) -> (memref<2xf32>) {
2222
%0 = "test.source"() : () -> (memref<2xf32>)
2323
return %0 : memref<2xf32>
2424
}
2525

26-
// CHECK-LABEL: func @multiple_results(
26+
// CHECK-LABEL: func private @multiple_results(
2727
// CHECK-SAME: %[[ARG0:.*]]: memref<1xf32>,
2828
// CHECK-SAME: %[[ARG1:.*]]: memref<2xf32>) {
2929
// CHECK: %[[RESULTS:.*]]:2 = "test.source"() : () -> (memref<1xf32>, memref<2xf32>)
3030
// CHECK: memref.copy %[[RESULTS]]#0, %[[ARG0]] : memref<1xf32> to memref<1xf32>
3131
// CHECK: memref.copy %[[RESULTS]]#1, %[[ARG1]] : memref<2xf32> to memref<2xf32>
3232
// CHECK: return
3333
// CHECK: }
34-
func.func @multiple_results() -> (memref<1xf32>, memref<2xf32>) {
34+
func.func private @multiple_results() -> (memref<1xf32>, memref<2xf32>) {
3535
%0, %1 = "test.source"() : () -> (memref<1xf32>, memref<2xf32>)
3636
return %0, %1 : memref<1xf32>, memref<2xf32>
3737
}
3838

39-
// CHECK-LABEL: func @non_memref_types(
39+
// CHECK-LABEL: func private @non_memref_types(
4040
// CHECK-SAME: %[[OUTPARAM:.*]]: memref<f32>) -> (i1, i32) {
4141
// CHECK: %[[RESULT1:.*]]:3 = "test.source"() : () -> (i1, memref<f32>, i32)
4242
// CHECK: memref.copy %[[RESULT1]]#1, %[[OUTPARAM]] : memref<f32> to memref<f32>
4343
// CHECK: return %[[RESULT1]]#0, %[[RESULT1]]#2 : i1, i32
4444
// CHECK: }
45-
func.func @non_memref_types() -> (i1, memref<f32>, i32) {
45+
func.func private @non_memref_types() -> (i1, memref<f32>, i32) {
4646
%0, %1, %2 = "test.source"() : () -> (i1, memref<f32>, i32)
4747
return %0, %1, %2 : i1, memref<f32>, i32
4848
}
4949

50-
// CHECK: func private @external_function(memref<f32>)
50+
// CHECK: func private @external_function() -> memref<f32>
5151
func.func private @external_function() -> (memref<f32>)
52-
// CHECK: func private @result_attrs(memref<f32> {test.some_attr})
52+
// CHECK: func private @result_attrs() -> (memref<f32> {test.some_attr})
5353
func.func private @result_attrs() -> (memref<f32> {test.some_attr})
54-
// CHECK: func private @mixed_result_attrs(memref<1xf32>, memref<2xf32> {test.some_attr}, memref<3xf32>)
54+
// CHECK: func private @mixed_result_attrs() -> (memref<1xf32>, memref<2xf32> {test.some_attr}, memref<3xf32>)
5555
func.func private @mixed_result_attrs() -> (memref<1xf32>, memref<2xf32> {test.some_attr}, memref<3xf32>)
5656

5757
// -----
5858

59-
// CHECK-LABEL: func private @callee(memref<1xf32>)
59+
// CHECK-LABEL: func private @callee() -> memref<1xf32>
6060
func.func private @callee() -> memref<1xf32>
6161

6262
// CHECK-LABEL: func @call_basic() {
63-
// CHECK: %[[OUTPARAM:.*]] = memref.alloc() : memref<1xf32>
64-
// CHECK: call @callee(%[[OUTPARAM]]) : (memref<1xf32>) -> ()
63+
// CHECK: %[[OUTPARAM:.*]] = call @callee() : () -> memref<1xf32>
6564
// CHECK: "test.sink"(%[[OUTPARAM]]) : (memref<1xf32>) -> ()
6665
// CHECK: return
6766
// CHECK: }
@@ -73,14 +72,12 @@ func.func @call_basic() {
7372

7473
// -----
7574

76-
// CHECK-LABEL: func private @callee(memref<1xf32>, memref<2xf32>)
75+
// CHECK-LABEL: func private @callee() -> (memref<1xf32>, memref<2xf32>)
7776
func.func private @callee() -> (memref<1xf32>, memref<2xf32>)
7877

7978
// CHECK-LABEL: func @call_multiple_result() {
80-
// CHECK: %[[RESULT0:.*]] = memref.alloc() : memref<1xf32>
81-
// CHECK: %[[RESULT1:.*]] = memref.alloc() : memref<2xf32>
82-
// CHECK: call @callee(%[[RESULT0]], %[[RESULT1]]) : (memref<1xf32>, memref<2xf32>) -> ()
83-
// CHECK: "test.sink"(%[[RESULT0]], %[[RESULT1]]) : (memref<1xf32>, memref<2xf32>) -> ()
79+
// CHECK: %[[RESULTS:.*]]:2 = call @callee() : () -> (memref<1xf32>, memref<2xf32>)
80+
// CHECK: "test.sink"(%[[RESULTS]]#0, %[[RESULTS]]#1) : (memref<1xf32>, memref<2xf32>) -> ()
8481
// CHECK: }
8582
func.func @call_multiple_result() {
8683
%0, %1 = call @callee() : () -> (memref<1xf32>, memref<2xf32>)
@@ -89,13 +86,12 @@ func.func @call_multiple_result() {
8986

9087
// -----
9188

92-
// CHECK-LABEL: func private @callee(memref<1xf32>) -> (i1, i32)
89+
// CHECK-LABEL: func private @callee() -> (i1, memref<1xf32>, i32)
9390
func.func private @callee() -> (i1, memref<1xf32>, i32)
9491

9592
// CHECK-LABEL: func @call_non_memref_result() {
96-
// CHECK: %[[RESULT0:.*]] = memref.alloc() : memref<1xf32>
97-
// CHECK: %[[NON_MEMREF_RESULTS:.*]]:2 = call @callee(%[[RESULT0]]) : (memref<1xf32>) -> (i1, i32)
98-
// CHECK: "test.sink"(%[[NON_MEMREF_RESULTS]]#0, %[[RESULT0]], %[[NON_MEMREF_RESULTS]]#1) : (i1, memref<1xf32>, i32) -> ()
93+
// CHECK: %[[RESULTS:.*]]:3 = call @callee() : () -> (i1, memref<1xf32>, i32)
94+
// CHECK: "test.sink"(%[[RESULTS]]#0, %[[RESULTS]]#1, %[[RESULTS]]#2) : (i1, memref<1xf32>, i32) -> ()
9995
// CHECK: }
10096
func.func @call_non_memref_result() {
10197
%0, %1, %2 = call @callee() : () -> (i1, memref<1xf32>, i32)
@@ -104,10 +100,13 @@ func.func @call_non_memref_result() {
104100

105101
// -----
106102

107-
func.func private @callee() -> (memref<?xf32>)
103+
func.func private @callee(%size: index) -> (memref<?xf32>) {
104+
%alloc = memref.alloc(%size) : memref<?xf32>
105+
return %alloc : memref<?xf32>
106+
}
108107

109-
func.func @call_non_memref_result() {
108+
func.func @call_non_memref_result(%size: index) {
110109
// expected-error @+1 {{cannot create out param for dynamically shaped result}}
111-
%0 = call @callee() : () -> (memref<?xf32>)
110+
%0 = call @callee(%size) : (index) -> (memref<?xf32>)
112111
"test.sink"(%0) : (memref<?xf32>) -> ()
113112
}

0 commit comments

Comments
 (0)