Skip to content

Commit 3f99d2f

Browse files
[mlir][bufferize] Make drop-equivalent-buffer-results only support functions that are neither public nor extern (llvm#163001)
The callers of public or extern functions are unknown, so their function signatures cannot be changed.
1 parent d7fc770 commit 3f99d2f

File tree

6 files changed

+42
-41
lines changed

6 files changed

+42
-41
lines changed

mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,13 @@ mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) {
6464
module.walk([&](func::CallOp callOp) {
6565
if (func::FuncOp calledFunc =
6666
dyn_cast_or_null<func::FuncOp>(callOp.resolveCallable())) {
67-
callerMap[calledFunc].insert(callOp);
67+
if (!calledFunc.isPublic() && !calledFunc.isExternal())
68+
callerMap[calledFunc].insert(callOp);
6869
}
6970
});
7071

7172
for (auto funcOp : module.getOps<func::FuncOp>()) {
72-
if (funcOp.isExternal())
73+
if (funcOp.isExternal() || funcOp.isPublic())
7374
continue;
7475
func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
7576
// TODO: Support functions with multiple blocks.

mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-allow-return-allocs.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,13 @@ func.func @main(%t: tensor<?xf32>, %sz: index, %idx: index) -> (f32, f32) {
6565

6666
// -----
6767

68-
func.func @return_arg(%A: tensor<?xf32>) -> tensor<?xf32> {
68+
func.func private @return_arg(%A: tensor<?xf32>) -> tensor<?xf32> {
6969
func.return %A : tensor<?xf32>
7070
}
71-
// CHECK-LABEL: func @return_arg
71+
// CHECK-LABEL: func private @return_arg
7272
// CHECK-SAME: %[[A:.*]]: memref<?xf32
7373
// CHECK-NOT: return %[[A]]
7474

75-
// NO-DROP-LABEL: func @return_arg
75+
// NO-DROP-LABEL: func private @return_arg
7676
// NO-DROP-SAME: %[[A:.*]]: memref<?xf32
7777
// NO-DROP: return %[[A]]

mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -171,9 +171,9 @@ func.func @func_without_tensor_args(%v : vector<10xf32>) -> () {
171171
// Bufferization of a function that is reading and writing. %t0 is writable, so
172172
// no copy should be inserted.
173173

174-
// CHECK-LABEL: func @inner_func(
174+
// CHECK-LABEL: func private @inner_func(
175175
// CHECK-SAME: %[[arg0:.*]]: memref<?xf32
176-
func.func @inner_func(%t: tensor<?xf32>) -> (tensor<?xf32>, f32) {
176+
func.func private @inner_func(%t: tensor<?xf32>) -> (tensor<?xf32>, f32) {
177177
// CHECK-NOT: copy
178178
%f = arith.constant 1.0 : f32
179179
%c0 = arith.constant 0 : index
@@ -186,9 +186,9 @@ func.func @inner_func(%t: tensor<?xf32>) -> (tensor<?xf32>, f32) {
186186
return %0, %1 : tensor<?xf32>, f32
187187
}
188188

189-
// CHECK-LABEL: func @call_func_with_non_tensor_return(
189+
// CHECK-LABEL: func private @call_func_with_non_tensor_return(
190190
// CHECK-SAME: %[[arg0:.*]]: memref<?xf32
191-
func.func @call_func_with_non_tensor_return(
191+
func.func private @call_func_with_non_tensor_return(
192192
%t0: tensor<?xf32> {bufferization.writable = true}) -> (f32, tensor<?xf32>) {
193193
// CHECK-NOT: alloc
194194
// CHECK-NOT: copy
@@ -203,9 +203,9 @@ func.func @call_func_with_non_tensor_return(
203203
// Bufferization of a function that is reading and writing. %t0 is not writable,
204204
// so a copy is needed.
205205

206-
// CHECK-LABEL: func @inner_func(
206+
// CHECK-LABEL: func private @inner_func(
207207
// CHECK-SAME: %[[arg0:.*]]: memref<?xf32
208-
func.func @inner_func(%t: tensor<?xf32>) -> (tensor<?xf32>, f32) {
208+
func.func private @inner_func(%t: tensor<?xf32>) -> (tensor<?xf32>, f32) {
209209
// CHECK-NOT: copy
210210
%f = arith.constant 1.0 : f32
211211
%c0 = arith.constant 0 : index
@@ -276,10 +276,10 @@ func.func @main(%t: tensor<?xf32> {bufferization.writable = false}) -> (f32) {
276276

277277
// This function does not read, just write. We need an alloc, but no copy.
278278

279-
// CHECK-LABEL: func @does_not_read(
279+
// CHECK-LABEL: func private @does_not_read(
280280
// CHECK-NOT: alloc
281281
// CHECK-NOT: copy
282-
func.func @does_not_read(%t: tensor<?xf32>) -> tensor<?xf32> {
282+
func.func private @does_not_read(%t: tensor<?xf32>) -> tensor<?xf32> {
283283
%f0 = arith.constant 0.0 : f32
284284
%r = linalg.fill ins(%f0 : f32) outs(%t : tensor<?xf32>) -> tensor<?xf32>
285285
return %r : tensor<?xf32>
@@ -354,9 +354,9 @@ func.func @main() {
354354

355355
// A write inside an scf.execute_region. An equivalent tensor is yielded.
356356

357-
// CHECK-LABEL: func @execute_region_test(
357+
// CHECK-LABEL: func private @execute_region_test(
358358
// CHECK-SAME: %[[m1:.*]]: memref<?xf32
359-
func.func @execute_region_test(%t1 : tensor<?xf32>)
359+
func.func private @execute_region_test(%t1 : tensor<?xf32>)
360360
-> (f32, tensor<?xf32>, f32)
361361
{
362362
%f1 = arith.constant 0.0 : f32
@@ -397,11 +397,11 @@ func.func @no_inline_execute_region_not_canonicalized() {
397397
// CHECK: func private @some_external_func(memref<?xf32, strided<[?], offset: ?>>)
398398
func.func private @some_external_func(tensor<?xf32>)
399399

400-
// CHECK: func @scf_for_with_tensor_insert_slice(
400+
// CHECK: func private @scf_for_with_tensor_insert_slice(
401401
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref<?xf32, strided<[?], offset: ?>>
402402
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: memref<?xf32, strided<[?], offset: ?>>
403403
// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: memref<4xf32, strided<[?], offset: ?>>
404-
func.func @scf_for_with_tensor_insert_slice(
404+
func.func private @scf_for_with_tensor_insert_slice(
405405
%A : tensor<?xf32>, %B : tensor<?xf32>, %C : tensor<4xf32>,
406406
%lb : index, %ub : index, %step : index)
407407
-> (tensor<?xf32>, tensor<?xf32>)
@@ -456,11 +456,11 @@ func.func @bar(
456456

457457
// -----
458458

459-
// CHECK: func @init_and_dot(
459+
// CHECK: func private @init_and_dot(
460460
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref<64xf32, strided<[?], offset: ?>>
461461
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: memref<64xf32, strided<[?], offset: ?>>
462462
// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: memref<f32, strided<[], offset: ?>>
463-
func.func @init_and_dot(%a: tensor<64xf32>, %b: tensor<64xf32>, %c: tensor<f32>) -> tensor<f32> {
463+
func.func private @init_and_dot(%a: tensor<64xf32>, %b: tensor<64xf32>, %c: tensor<f32>) -> tensor<f32> {
464464
// CHECK-NEXT: %[[C0:.*]] = arith.constant 0{{.*}} : f32
465465
%v0 = arith.constant 0.0 : f32
466466

@@ -574,9 +574,9 @@ func.func @entry(%A : tensor<?xf32> {bufferization.buffer_layout = affine_map<(i
574574

575575
// No alloc or copy inside of the loop.
576576

577-
// CHECK-LABEL: func @inner_func(
577+
// CHECK-LABEL: func private @inner_func(
578578
// CHECK-SAME: %[[arg0:.*]]: memref<?xf32
579-
func.func @inner_func(%t: tensor<?xf32>) -> tensor<?xf32> {
579+
func.func private @inner_func(%t: tensor<?xf32>) -> tensor<?xf32> {
580580
%f = arith.constant 1.0 : f32
581581
%c0 = arith.constant 0 : index
582582
// CHECK: memref.store %{{.*}}, %[[arg0]]

mlir/test/Dialect/Linalg/one-shot-bufferize.mlir

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@
1010

1111
// TODO: Some test cases from this file should be moved to other dialects.
1212

13-
// CHECK-LABEL: func @fill_inplace(
13+
// CHECK-LABEL: func private @fill_inplace(
1414
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref<?xf32, strided<[?], offset: ?>>
15-
// CHECK-NO-LAYOUT-MAP-LABEL: func @fill_inplace(%{{.*}}: memref<?xf32>) {
16-
func.func @fill_inplace(
15+
// CHECK-NO-LAYOUT-MAP-LABEL: func private @fill_inplace(%{{.*}}: memref<?xf32>) {
16+
func.func private @fill_inplace(
1717
%A : tensor<?xf32> {bufferization.writable = true})
1818
-> tensor<?xf32>
1919
{
@@ -56,10 +56,10 @@ func.func @not_inplace(
5656
// -----
5757

5858

59-
// CHECK-LABEL: func @not_inplace
59+
// CHECK-LABEL: func private @not_inplace
6060
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref<?x?xf32, strided<[?, ?], offset: ?>>) {
61-
// CHECK-NO-LAYOUT-MAP-LABEL: func @not_inplace(%{{.*}}: memref<?x?xf32>) {
62-
func.func @not_inplace(
61+
// CHECK-NO-LAYOUT-MAP-LABEL: func private @not_inplace(%{{.*}}: memref<?x?xf32>) {
62+
func.func private @not_inplace(
6363
%A : tensor<?x?xf32> {bufferization.writable = true})
6464
-> tensor<?x?xf32>
6565
{
@@ -235,7 +235,7 @@ func.func @dominance_violation_bug_1(
235235

236236
// -----
237237

238-
func.func @gather_like(
238+
func.func private @gather_like(
239239
%arg0 : tensor<?x?xf32> {bufferization.writable = false},
240240
%arg1 : tensor<?xi32> {bufferization.writable = false},
241241
%arg2 : tensor<?x?xf32> {bufferization.writable = true})
@@ -254,7 +254,7 @@ func.func @gather_like(
254254
} -> tensor<?x?xf32>
255255
return %0 : tensor<?x?xf32>
256256
}
257-
// CHECK-LABEL: func @gather_like(
257+
// CHECK-LABEL: func private @gather_like(
258258
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf32,
259259
// CHECK-SAME: %[[ARG1:.+]]: memref<?xi32
260260
// CHECK-SAME: %[[ARG2:.+]]: memref<?x?xf32

mlir/test/Dialect/SCF/one-shot-bufferize.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@
88
// Test bufferization using memref types that have no layout map.
99
// RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="allow-return-allocs-from-loops unknown-type-conversion=identity-layout-map function-boundary-type-conversion=identity-layout-map bufferize-function-boundaries" -split-input-file -o /dev/null
1010

11-
// CHECK-LABEL: func @scf_for_yield_only(
11+
// CHECK-LABEL: func private @scf_for_yield_only(
1212
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref<?xf32, strided<[?], offset: ?>>,
1313
// CHECK-SAME: %[[t:[a-zA-Z0-9]*]]: memref<?xf32, strided<[?], offset: ?>>
1414
// CHECK-SAME: ) -> memref<?xf32> {
15-
func.func @scf_for_yield_only(
15+
func.func private @scf_for_yield_only(
1616
%A : tensor<?xf32> {bufferization.writable = false},
1717
%B : tensor<?xf32> {bufferization.writable = true},
1818
%lb : index, %ub : index, %step : index)
@@ -85,11 +85,11 @@ func.func @nested_scf_for(%A : tensor<?xf32> {bufferization.writable = true},
8585

8686
// -----
8787

88-
// CHECK-LABEL: func @scf_for_with_tensor.insert_slice
88+
// CHECK-LABEL: func private @scf_for_with_tensor.insert_slice
8989
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref<?xf32, strided<[?], offset: ?>>
9090
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: memref<?xf32, strided<[?], offset: ?>>
9191
// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: memref<4xf32, strided<[?], offset: ?>>
92-
func.func @scf_for_with_tensor.insert_slice(
92+
func.func private @scf_for_with_tensor.insert_slice(
9393
%A : tensor<?xf32> {bufferization.writable = false},
9494
%B : tensor<?xf32> {bufferization.writable = true},
9595
%C : tensor<4xf32> {bufferization.writable = false},
@@ -471,11 +471,11 @@ func.func @scf_while_iter_arg_result_mismatch(%arg0: tensor<5xi1>,
471471

472472
// -----
473473

474-
// CHECK-LABEL: func.func @parallel_insert_slice_no_conflict(
474+
// CHECK-LABEL: func private @parallel_insert_slice_no_conflict(
475475
// CHECK-SAME: %[[idx:.*]]: index, %[[idx2:.*]]: index,
476476
// CHECK-SAME: %[[arg1:.*]]: memref<?xf32, strided{{.*}}>,
477477
// CHECK-SAME: %[[arg2:.*]]: memref<?xf32, strided{{.*}}>
478-
func.func @parallel_insert_slice_no_conflict(
478+
func.func private @parallel_insert_slice_no_conflict(
479479
%idx: index,
480480
%idx2: index,
481481
%arg1: tensor<?xf32> {bufferization.writable = true},

mlir/test/Dialect/Tensor/one-shot-bufferize.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@
88
// Test bufferization using memref types that have no layout map.
99
// RUN: mlir-opt %s -one-shot-bufferize="unknown-type-conversion=identity-layout-map bufferize-function-boundaries" -split-input-file -o /dev/null
1010

11-
// CHECK-LABEL: func @insert_slice_fun
11+
// CHECK-LABEL: func private @insert_slice_fun
1212
// CHECK-SAME: %[[A0:[a-zA-Z0-9]*]]: memref<?xf32, strided<[?], offset: ?>>,
1313
// CHECK-SAME: %[[A1:[a-zA-Z0-9]*]]: memref<?xf32, strided<[?], offset: ?>>,
1414
// CHECK-SAME: %[[t0:[a-zA-Z0-9]*]]: memref<4xf32, strided<[?], offset: ?>>,
1515
// CHECK-SAME: %[[t1:[a-zA-Z0-9]*]]: memref<4xf32, strided<[?], offset: ?>>
16-
func.func @insert_slice_fun(
16+
func.func private @insert_slice_fun(
1717
%A0 : tensor<?xf32> {bufferization.writable = false},
1818
%A1 : tensor<?xf32> {bufferization.writable = true},
1919
%t0 : tensor<4xf32> {bufferization.writable = false},
@@ -331,12 +331,12 @@ func.func @dim_not_reading(%t: tensor<?xf32>, %f: f32, %pos: index)
331331
// -----
332332

333333
// CHECK: #[[$map:.*]] = affine_map<(d0) -> (d0 + 5)>
334-
// CHECK-LABEL: func.func @cast_retains_buffer_layout(
334+
// CHECK-LABEL: func.func private @cast_retains_buffer_layout(
335335
// CHECK-SAME: %[[t:.*]]: memref<?xf32, #[[$map]]>, %[[sz:.*]]: index) -> memref<?xf32, strided<[1], offset: 7>> {
336336
// CHECK: %[[casted:.*]] = memref.cast %[[t]] : memref<?xf32, #[[$map]]> to memref<10xf32, #[[$map]]>
337337
// CHECK: %[[slice:.*]] = memref.subview %[[casted]][2] [%[[sz]]] [1] : memref<10xf32, #[[$map]]> to memref<?xf32, strided<[1], offset: 7>>
338338
// CHECK: return %[[slice]]
339-
func.func @cast_retains_buffer_layout(
339+
func.func private @cast_retains_buffer_layout(
340340
%t: tensor<?xf32>
341341
{bufferization.buffer_layout = affine_map<(d0) -> (d0 + 5)>},
342342
%sz: index)
@@ -353,12 +353,12 @@ func.func @cast_retains_buffer_layout(
353353

354354
// -----
355355

356-
// CHECK-LABEL: func.func @cast_retains_buffer_layout_strided(
356+
// CHECK-LABEL: func private @cast_retains_buffer_layout_strided(
357357
// CHECK-SAME: %[[t:.*]]: memref<?xf32, strided<[1], offset: 5>>, %[[sz:.*]]: index) -> memref<?xf32, strided<[1], offset: 7>> {
358358
// CHECK: %[[casted:.*]] = memref.cast %[[t]] : memref<?xf32, strided<[1], offset: 5>> to memref<10xf32, strided<[1], offset: 5>>
359359
// CHECK: %[[slice:.*]] = memref.subview %[[casted]][2] [%[[sz]]] [1] : memref<10xf32, strided<[1], offset: 5>> to memref<?xf32, strided<[1], offset: 7>>
360360
// CHECK: return %[[slice]]
361-
func.func @cast_retains_buffer_layout_strided(
361+
func.func private @cast_retains_buffer_layout_strided(
362362
%t: tensor<?xf32>
363363
{bufferization.buffer_layout = strided<[1], offset: 5>},
364364
%sz: index)

0 commit comments

Comments
 (0)