@@ -23,6 +23,8 @@ namespace bufferization {
23
23
using namespace mlir ;
24
24
using AllocationFn = bufferization::BufferResultsToOutParamsOpts::AllocationFn;
25
25
using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn;
26
+ using AllocDynamicSizesMap =
27
+ llvm::DenseMap<func::FuncOp, SmallVector<SmallVector<Value>>>;
26
28
27
29
// / Return `true` if the given MemRef type has a fully dynamic layout.
28
30
static bool hasFullyDynamicLayoutMap (MemRefType type) {
@@ -43,6 +45,50 @@ static bool hasStaticIdentityLayout(MemRefType type) {
43
45
return type.getLayout ().isIdentity ();
44
46
}
45
47
48
+ // / Return the dynamic shapes of the `memref` based on the defining op. If the
49
+ // / complete dynamic shape fails to be captured, return an empty value.
50
+ // / Currently, only function block arguments are supported for capturing.
51
+ static SmallVector<Value> getDynamicSize (Value memref, func::FuncOp funcOp) {
52
+ Operation *defOp = memref.getDefiningOp ();
53
+ if (!defOp)
54
+ return {};
55
+ auto operands = defOp->getOperands ();
56
+ SmallVector<Value> dynamicSizes;
57
+ for (Value size : operands) {
58
+ if (!isa<IndexType>(size.getType ()))
59
+ continue ;
60
+
61
+ BlockArgument sizeSrc = dyn_cast<BlockArgument>(size);
62
+ if (!sizeSrc)
63
+ return {};
64
+ auto arguments = funcOp.getArguments ();
65
+ auto iter = llvm::find (arguments, sizeSrc);
66
+ if (iter == arguments.end ())
67
+ return {};
68
+ dynamicSizes.push_back (*iter);
69
+ }
70
+ return dynamicSizes;
71
+ }
72
+
73
+ // / Returns the dynamic sizes at the callee, through the call relationship
74
+ // / between the caller and callee.
75
+ static SmallVector<Value> mapDynamicSizeAtCaller (func::CallOp call,
76
+ func::FuncOp callee,
77
+ ValueRange dynamicSizes) {
78
+ SmallVector<Value> mappedDynamicSizes;
79
+ for (Value size : dynamicSizes) {
80
+ for (auto [src, dst] :
81
+ llvm::zip_first (call.getOperands (), callee.getArguments ())) {
82
+ if (size != dst)
83
+ continue ;
84
+ mappedDynamicSizes.push_back (src);
85
+ }
86
+ }
87
+ assert (mappedDynamicSizes.size () == dynamicSizes.size () &&
88
+ " could not find all dynamic sizes" );
89
+ return mappedDynamicSizes;
90
+ }
91
+
46
92
// Updates the func op and entry block.
47
93
//
48
94
// Any args appended to the entry block are added to `appendedEntryArgs`.
@@ -109,6 +155,7 @@ updateFuncOp(func::FuncOp func,
109
155
// the given out-params.
110
156
static LogicalResult
111
157
updateReturnOps (func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
158
+ AllocDynamicSizesMap &map,
112
159
const bufferization::BufferResultsToOutParamsOpts &options) {
113
160
auto res = func.walk ([&](func::ReturnOp op) {
114
161
SmallVector<Value, 6 > copyIntoOutParams;
@@ -120,12 +167,22 @@ updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
120
167
keepAsReturnOperands.push_back (operand);
121
168
}
122
169
OpBuilder builder (op);
170
+ SmallVector<SmallVector<Value>> dynamicSizes;
123
171
for (auto [orig, arg] : llvm::zip (copyIntoOutParams, appendedEntryArgs)) {
124
- if (options.hoistStaticAllocs &&
172
+ bool hoistStaticAllocs =
173
+ options.hoistStaticAllocs &&
174
+ cast<MemRefType>(orig.getType ()).hasStaticShape ();
175
+ bool hoistDynamicAllocs =
176
+ options.hoistDynamicAllocs &&
177
+ !cast<MemRefType>(orig.getType ()).hasStaticShape ();
178
+ if ((hoistStaticAllocs || hoistDynamicAllocs) &&
125
179
isa_and_nonnull<bufferization::AllocationOpInterface>(
126
- orig.getDefiningOp ()) &&
127
- mlir::cast<MemRefType>(orig.getType ()).hasStaticShape ()) {
180
+ orig.getDefiningOp ())) {
128
181
orig.replaceAllUsesWith (arg);
182
+ if (hoistDynamicAllocs) {
183
+ SmallVector<Value> dynamicSize = getDynamicSize (orig, func);
184
+ dynamicSizes.push_back (dynamicSize);
185
+ }
129
186
orig.getDefiningOp ()->erase ();
130
187
} else {
131
188
if (failed (options.memCpyFn (builder, op.getLoc (), orig, arg)))
@@ -134,6 +191,10 @@ updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
134
191
}
135
192
func::ReturnOp::create (builder, op.getLoc (), keepAsReturnOperands);
136
193
op.erase ();
194
+ auto dynamicSizePair =
195
+ std::pair<func::FuncOp, SmallVector<SmallVector<Value>>>(func,
196
+ dynamicSizes);
197
+ map.insert (dynamicSizePair);
137
198
return WalkResult::advance ();
138
199
});
139
200
return failure (res.wasInterrupted ());
@@ -142,7 +203,7 @@ updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
142
203
// Updates all CallOps in the scope of the given ModuleOp by allocating
143
204
// temporary buffers for newly introduced out params.
144
205
static LogicalResult
145
- updateCalls (ModuleOp module ,
206
+ updateCalls (ModuleOp module , const AllocDynamicSizesMap &map,
146
207
const bufferization::BufferResultsToOutParamsOpts &options) {
147
208
bool didFail = false ;
148
209
SymbolTable symtab (module );
@@ -166,8 +227,15 @@ updateCalls(ModuleOp module,
166
227
}
167
228
SmallVector<Value, 6 > outParams;
168
229
OpBuilder builder (op);
230
+ SmallVector<SmallVector<Value>> dynamicSizes = map.lookup (callee);
231
+ size_t dynamicSizesIndex = 0 ;
169
232
for (Value memref : replaceWithOutParams) {
170
- if (!cast<MemRefType>(memref.getType ()).hasStaticShape ()) {
233
+ SmallVector<Value> dynamicSize = dynamicSizes.size () > dynamicSizesIndex
234
+ ? dynamicSizes[dynamicSizesIndex]
235
+ : SmallVector<Value>();
236
+ bool memrefStaticShape =
237
+ cast<MemRefType>(memref.getType ()).hasStaticShape ();
238
+ if (!memrefStaticShape && dynamicSize.empty ()) {
171
239
op.emitError ()
172
240
<< " cannot create out param for dynamically shaped result" ;
173
241
didFail = true ;
@@ -177,8 +245,15 @@ updateCalls(ModuleOp module,
177
245
auto allocType =
178
246
MemRefType::get (memrefType.getShape (), memrefType.getElementType (),
179
247
AffineMap (), memrefType.getMemorySpace ());
248
+
249
+ if (memrefStaticShape) {
250
+ dynamicSize = {};
251
+ } else {
252
+ ++dynamicSizesIndex;
253
+ dynamicSize = mapDynamicSizeAtCaller (op, callee, dynamicSize);
254
+ }
180
255
auto maybeOutParam =
181
- options.allocationFn (builder, op.getLoc (), allocType);
256
+ options.allocationFn (builder, op.getLoc (), allocType, dynamicSize );
182
257
if (failed (maybeOutParam)) {
183
258
op.emitError () << " failed to create allocation op" ;
184
259
didFail = true ;
@@ -213,6 +288,9 @@ updateCalls(ModuleOp module,
213
288
LogicalResult mlir::bufferization::promoteBufferResultsToOutParams (
214
289
ModuleOp module ,
215
290
const bufferization::BufferResultsToOutParamsOpts &options) {
291
+ // It maps the shape source of the dynamic shape memref returned by each
292
+ // function.
293
+ AllocDynamicSizesMap map;
216
294
for (auto func : module .getOps <func::FuncOp>()) {
217
295
if (!options.filterFn (&func))
218
296
continue ;
@@ -222,11 +300,11 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
222
300
return failure ();
223
301
if (func.isExternal ())
224
302
continue ;
225
- if (failed (updateReturnOps (func, appendedEntryArgs, options))) {
303
+ if (failed (updateReturnOps (func, appendedEntryArgs, map, options))) {
226
304
return failure ();
227
305
}
228
306
}
229
- if (failed (updateCalls (module , options)))
307
+ if (failed (updateCalls (module , map, options)))
230
308
return failure ();
231
309
return success ();
232
310
}
@@ -243,6 +321,8 @@ struct BufferResultsToOutParamsPass
243
321
options.addResultAttribute = true ;
244
322
if (hoistStaticAllocs)
245
323
options.hoistStaticAllocs = true ;
324
+ if (hoistDynamicAllocs)
325
+ options.hoistDynamicAllocs = true ;
246
326
247
327
if (failed (bufferization::promoteBufferResultsToOutParams (getOperation (),
248
328
options)))
0 commit comments