Skip to content

Commit 1087c10

Browse files
[mlir][bufferize] Add hoist-dynamic-allocs-option to buffer-results-to-out-params (#160985)
Add hoist-dynamic-allocs-option to buffer-results-to-out-params. This PR supported that obtain the size of the dynamic shape memref through the caller-callee relationship.
1 parent 1bd9c1b commit 1087c10

File tree

5 files changed

+178
-12
lines changed

5 files changed

+178
-12
lines changed

mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,8 @@ struct BufferResultsToOutParamsOpts {
131131
/// Allocator function: Generate a memref allocation with the given type.
132132
/// Since `promoteBufferResultsToOutParams` doesn't allow dynamically shaped
133133
/// results, we don't allow passing a range of values for dynamic dims.
134-
using AllocationFn =
135-
std::function<FailureOr<Value>(OpBuilder &, Location, MemRefType)>;
134+
using AllocationFn = std::function<FailureOr<Value>(OpBuilder &, Location,
135+
MemRefType, ValueRange)>;
136136

137137
/// Memcpy function: Generate a memcpy between two memrefs.
138138
using MemCpyFn =
@@ -147,8 +147,9 @@ struct BufferResultsToOutParamsOpts {
147147
/// Allocation function; used to allocate a memref.
148148
/// Default memref.alloc is used
149149
AllocationFn allocationFn = [](OpBuilder &builder, Location loc,
150-
MemRefType type) {
151-
return memref::AllocOp::create(builder, loc, type).getResult();
150+
MemRefType type, ValueRange dynamicSizes) {
151+
return memref::AllocOp::create(builder, loc, type, dynamicSizes)
152+
.getResult();
152153
};
153154

154155
/// Memcpy function; used to create a copy between two memrefs.
@@ -166,6 +167,10 @@ struct BufferResultsToOutParamsOpts {
166167
/// If true, the pass eliminates the memref.alloc and memcpy if the returned
167168
/// memref is allocated in the current function.
168169
bool hoistStaticAllocs = false;
170+
171+
/// If true, the pass eliminates the memref.alloc and memcpy if the returned
172+
/// memref is allocated in the current function and has dynamic shape.
173+
bool hoistDynamicAllocs = false;
169174
};
170175

171176
/// Replace buffers that are returned from a function with an out parameter.

mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,8 @@ def BufferResultsToOutParamsPass
256256
"Add the attribute 'bufferize.result' to all output parameters.">,
257257
Option<"hoistStaticAllocs", "hoist-static-allocs", "bool",
258258
/*default=*/"false", "Hoist static allocations to call sites.">,
259+
Option<"hoistDynamicAllocs", "hoist-dynamic-allocs", "bool",
260+
/*default=*/"false", "Hoist dynamic allocations to call sites.">,
259261
];
260262
let dependentDialects = ["memref::MemRefDialect"];
261263
}

mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp

Lines changed: 88 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ namespace bufferization {
2323
using namespace mlir;
2424
using AllocationFn = bufferization::BufferResultsToOutParamsOpts::AllocationFn;
2525
using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn;
26+
using AllocDynamicSizesMap =
27+
llvm::DenseMap<func::FuncOp, SmallVector<SmallVector<Value>>>;
2628

2729
/// Return `true` if the given MemRef type has a fully dynamic layout.
2830
static bool hasFullyDynamicLayoutMap(MemRefType type) {
@@ -43,6 +45,50 @@ static bool hasStaticIdentityLayout(MemRefType type) {
4345
return type.getLayout().isIdentity();
4446
}
4547

48+
/// Return the dynamic shapes of the `memref` based on the defining op. If the
49+
/// complete dynamic shape fails to be captured, return an empty value.
50+
/// Currently, only function block arguments are supported for capturing.
51+
static SmallVector<Value> getDynamicSize(Value memref, func::FuncOp funcOp) {
52+
Operation *defOp = memref.getDefiningOp();
53+
if (!defOp)
54+
return {};
55+
auto operands = defOp->getOperands();
56+
SmallVector<Value> dynamicSizes;
57+
for (Value size : operands) {
58+
if (!isa<IndexType>(size.getType()))
59+
continue;
60+
61+
BlockArgument sizeSrc = dyn_cast<BlockArgument>(size);
62+
if (!sizeSrc)
63+
return {};
64+
auto arguments = funcOp.getArguments();
65+
auto iter = llvm::find(arguments, sizeSrc);
66+
if (iter == arguments.end())
67+
return {};
68+
dynamicSizes.push_back(*iter);
69+
}
70+
return dynamicSizes;
71+
}
72+
73+
/// Returns the dynamic sizes at the callee, through the call relationship
74+
/// between the caller and callee.
75+
static SmallVector<Value> mapDynamicSizeAtCaller(func::CallOp call,
76+
func::FuncOp callee,
77+
ValueRange dynamicSizes) {
78+
SmallVector<Value> mappedDynamicSizes;
79+
for (Value size : dynamicSizes) {
80+
for (auto [src, dst] :
81+
llvm::zip_first(call.getOperands(), callee.getArguments())) {
82+
if (size != dst)
83+
continue;
84+
mappedDynamicSizes.push_back(src);
85+
}
86+
}
87+
assert(mappedDynamicSizes.size() == dynamicSizes.size() &&
88+
"could not find all dynamic sizes");
89+
return mappedDynamicSizes;
90+
}
91+
4692
// Updates the func op and entry block.
4793
//
4894
// Any args appended to the entry block are added to `appendedEntryArgs`.
@@ -109,6 +155,7 @@ updateFuncOp(func::FuncOp func,
109155
// the given out-params.
110156
static LogicalResult
111157
updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
158+
AllocDynamicSizesMap &map,
112159
const bufferization::BufferResultsToOutParamsOpts &options) {
113160
auto res = func.walk([&](func::ReturnOp op) {
114161
SmallVector<Value, 6> copyIntoOutParams;
@@ -120,12 +167,22 @@ updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
120167
keepAsReturnOperands.push_back(operand);
121168
}
122169
OpBuilder builder(op);
170+
SmallVector<SmallVector<Value>> dynamicSizes;
123171
for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
124-
if (options.hoistStaticAllocs &&
172+
bool hoistStaticAllocs =
173+
options.hoistStaticAllocs &&
174+
cast<MemRefType>(orig.getType()).hasStaticShape();
175+
bool hoistDynamicAllocs =
176+
options.hoistDynamicAllocs &&
177+
!cast<MemRefType>(orig.getType()).hasStaticShape();
178+
if ((hoistStaticAllocs || hoistDynamicAllocs) &&
125179
isa_and_nonnull<bufferization::AllocationOpInterface>(
126-
orig.getDefiningOp()) &&
127-
mlir::cast<MemRefType>(orig.getType()).hasStaticShape()) {
180+
orig.getDefiningOp())) {
128181
orig.replaceAllUsesWith(arg);
182+
if (hoistDynamicAllocs) {
183+
SmallVector<Value> dynamicSize = getDynamicSize(orig, func);
184+
dynamicSizes.push_back(dynamicSize);
185+
}
129186
orig.getDefiningOp()->erase();
130187
} else {
131188
if (failed(options.memCpyFn(builder, op.getLoc(), orig, arg)))
@@ -134,6 +191,10 @@ updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
134191
}
135192
func::ReturnOp::create(builder, op.getLoc(), keepAsReturnOperands);
136193
op.erase();
194+
auto dynamicSizePair =
195+
std::pair<func::FuncOp, SmallVector<SmallVector<Value>>>(func,
196+
dynamicSizes);
197+
map.insert(dynamicSizePair);
137198
return WalkResult::advance();
138199
});
139200
return failure(res.wasInterrupted());
@@ -142,7 +203,7 @@ updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
142203
// Updates all CallOps in the scope of the given ModuleOp by allocating
143204
// temporary buffers for newly introduced out params.
144205
static LogicalResult
145-
updateCalls(ModuleOp module,
206+
updateCalls(ModuleOp module, const AllocDynamicSizesMap &map,
146207
const bufferization::BufferResultsToOutParamsOpts &options) {
147208
bool didFail = false;
148209
SymbolTable symtab(module);
@@ -166,8 +227,15 @@ updateCalls(ModuleOp module,
166227
}
167228
SmallVector<Value, 6> outParams;
168229
OpBuilder builder(op);
230+
SmallVector<SmallVector<Value>> dynamicSizes = map.lookup(callee);
231+
size_t dynamicSizesIndex = 0;
169232
for (Value memref : replaceWithOutParams) {
170-
if (!cast<MemRefType>(memref.getType()).hasStaticShape()) {
233+
SmallVector<Value> dynamicSize = dynamicSizes.size() > dynamicSizesIndex
234+
? dynamicSizes[dynamicSizesIndex]
235+
: SmallVector<Value>();
236+
bool memrefStaticShape =
237+
cast<MemRefType>(memref.getType()).hasStaticShape();
238+
if (!memrefStaticShape && dynamicSize.empty()) {
171239
op.emitError()
172240
<< "cannot create out param for dynamically shaped result";
173241
didFail = true;
@@ -177,8 +245,15 @@ updateCalls(ModuleOp module,
177245
auto allocType =
178246
MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
179247
AffineMap(), memrefType.getMemorySpace());
248+
249+
if (memrefStaticShape) {
250+
dynamicSize = {};
251+
} else {
252+
++dynamicSizesIndex;
253+
dynamicSize = mapDynamicSizeAtCaller(op, callee, dynamicSize);
254+
}
180255
auto maybeOutParam =
181-
options.allocationFn(builder, op.getLoc(), allocType);
256+
options.allocationFn(builder, op.getLoc(), allocType, dynamicSize);
182257
if (failed(maybeOutParam)) {
183258
op.emitError() << "failed to create allocation op";
184259
didFail = true;
@@ -213,6 +288,9 @@ updateCalls(ModuleOp module,
213288
LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
214289
ModuleOp module,
215290
const bufferization::BufferResultsToOutParamsOpts &options) {
291+
// It maps the shape source of the dynamic shape memref returned by each
292+
// function.
293+
AllocDynamicSizesMap map;
216294
for (auto func : module.getOps<func::FuncOp>()) {
217295
if (!options.filterFn(&func))
218296
continue;
@@ -222,11 +300,11 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
222300
return failure();
223301
if (func.isExternal())
224302
continue;
225-
if (failed(updateReturnOps(func, appendedEntryArgs, options))) {
303+
if (failed(updateReturnOps(func, appendedEntryArgs, map, options))) {
226304
return failure();
227305
}
228306
}
229-
if (failed(updateCalls(module, options)))
307+
if (failed(updateCalls(module, map, options)))
230308
return failure();
231309
return success();
232310
}
@@ -243,6 +321,8 @@ struct BufferResultsToOutParamsPass
243321
options.addResultAttribute = true;
244322
if (hoistStaticAllocs)
245323
options.hoistStaticAllocs = true;
324+
if (hoistDynamicAllocs)
325+
options.hoistDynamicAllocs = true;
246326

247327
if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(),
248328
options)))
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
// RUN: mlir-opt -allow-unregistered-dialect -p 'builtin.module(buffer-results-to-out-params{hoist-dynamic-allocs})' %s -split-input-file | FileCheck %s
2+
3+
func.func private @single_alloc(%size : index) -> (memref<?xf32>) {
4+
%alloc = memref.alloc(%size) : memref<?xf32>
5+
return %alloc : memref<?xf32>
6+
}
7+
8+
func.func @single_alloc_test(%size : index) {
9+
%alloc = call @single_alloc(%size) : (index) -> (memref<?xf32>)
10+
"test.sink"(%alloc) : (memref<?xf32>) -> ()
11+
}
12+
13+
// CHECK-LABEL: func.func private @single_alloc(
14+
// CHECK-SAME: %{{.*}}: index,
15+
// CHECK-SAME: %{{.*}}: memref<?xf32>) {
16+
17+
// CHECK-LABEL: func.func @single_alloc_test(
18+
// CHECK-SAME: %[[size:.*]]: index) {
19+
// CHECK: %[[alloc:.*]] = memref.alloc(%[[size]]) : memref<?xf32>
20+
// CHECK: call @single_alloc(%[[size]], %[[alloc]]) : (index, memref<?xf32>) -> ()
21+
// CHECK: "test.sink"(%[[alloc]]) : (memref<?xf32>) -> ()
22+
// CHECK: }
23+
24+
// -----
25+
26+
func.func private @mult_alloc(%size0 : index, %size1 : index) -> (memref<?x?xf32>, memref<?xf32>) {
27+
%alloc0 = memref.alloc(%size0, %size1) : memref<?x?xf32>
28+
%alloc1 = memref.alloc(%size1) : memref<?xf32>
29+
return %alloc0, %alloc1 : memref<?x?xf32>, memref<?xf32>
30+
}
31+
32+
func.func @mult_alloc_test(%size0 : index, %size1: index) {
33+
%alloc0, %alloc1 = call @mult_alloc(%size0, %size1) : (index, index) -> (memref<?x?xf32>, memref<?xf32>)
34+
"test.sink"(%alloc0, %alloc1) : (memref<?x?xf32>, memref<?xf32>) -> ()
35+
}
36+
37+
// CHECK-LABEL: func private @mult_alloc(
38+
// CHECK-SAME: %{{.*}}: index, %{{.*}}: index,
39+
// CHECK-SAME: %{{.*}}: memref<?x?xf32>, %{{.*}}: memref<?xf32>) {
40+
41+
// CHECK-LABEL: func @mult_alloc_test(
42+
// CHECK-SAME: %[[size0:.*]]: index,
43+
// CHECK-SAME: %[[size1:.*]]: index) {
44+
// CHECK: %[[alloc0:.*]] = memref.alloc(%[[size0]], %[[size1]]) : memref<?x?xf32>
45+
// CHECK: %[[alloc1:.*]] = memref.alloc(%[[size1]]) : memref<?xf32>
46+
// CHECK: call @mult_alloc(%[[size0]], %[[size1]], %[[alloc0]], %[[alloc1]]) : (index, index, memref<?x?xf32>, memref<?xf32>) -> ()
47+
// CHECK: "test.sink"(%[[alloc0]], %[[alloc1]]) : (memref<?x?xf32>, memref<?xf32>) -> ()
48+
// CHECK: }
49+
50+
51+
// -----
52+
53+
func.func private @complex_alloc(%size0 : index, %size1 : index) -> (memref<?x?xf32>, memref<4xf32>, memref<?xf32>) {
54+
%alloc0 = memref.alloc(%size0, %size1) : memref<?x?xf32>
55+
%alloc1 = memref.alloc() : memref<4xf32>
56+
%alloc2 = memref.alloc(%size1) : memref<?xf32>
57+
return %alloc0, %alloc1, %alloc2 : memref<?x?xf32>, memref<4xf32>, memref<?xf32>
58+
}
59+
60+
func.func @complex_alloc_test(%size0 : index, %size1: index) {
61+
%alloc0, %alloc1, %alloc2 = call @complex_alloc(%size0, %size1) : (index, index) -> (memref<?x?xf32>, memref<4xf32>, memref<?xf32>)
62+
"test.sink"(%alloc0, %alloc1, %alloc2) : (memref<?x?xf32>, memref<4xf32>, memref<?xf32>) -> ()
63+
}
64+
65+
// CHECK-LABEL: func private @complex_alloc(
66+
// CHECK-SAME: %{{.*}}: index, %{{.*}}: index,
67+
// CHECK-SAME: %{{.*}}: memref<?x?xf32>,
68+
// CHECK-SAME: %{{.*}}: memref<4xf32>,
69+
// CHECK-SAME: %{{.*}}: memref<?xf32>) {
70+
71+
// CHECK-LABEL: func @complex_alloc_test(
72+
// CHECK-SAME: %[[size0:.*]]: index,
73+
// CHECK-SAME: %[[size1:.*]]: index) {
74+
// CHECK: %[[alloc0:.*]] = memref.alloc(%[[size0]], %[[size1]]) : memref<?x?xf32>
75+
// CHECK: %[[alloc1:.*]] = memref.alloc() : memref<4xf32>
76+
// CHECK: %[[alloc2:.*]] = memref.alloc(%[[size1]]) : memref<?xf32>
77+
// CHECK: call @complex_alloc(%[[size0]], %[[size1]], %[[alloc0]], %[[alloc1]], %[[alloc2]]) : (index, index, memref<?x?xf32>, memref<4xf32>, memref<?xf32>) -> ()
78+
// CHECK: "test.sink"(%[[alloc0]], %[[alloc1]], %[[alloc2]]) : (memref<?x?xf32>, memref<4xf32>, memref<?xf32>) -> ()
79+
// CHECK: }
File renamed without changes.

0 commit comments

Comments
 (0)