@@ -23,6 +23,8 @@ namespace bufferization {
2323using namespace mlir ;
2424using AllocationFn = bufferization::BufferResultsToOutParamsOpts::AllocationFn;
2525using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn;
26+ using AllocDynamicSizesMap =
27+ llvm::DenseMap<func::FuncOp, SmallVector<SmallVector<Value>>>;
2628
2729// / Return `true` if the given MemRef type has a fully dynamic layout.
2830static bool hasFullyDynamicLayoutMap (MemRefType type) {
@@ -43,30 +45,24 @@ static bool hasStaticIdentityLayout(MemRefType type) {
4345 return type.getLayout ().isIdentity ();
4446}
4547
46- // / Return the dynamic shapes of the `memref` based on the define op. If the
48+ // / Return the dynamic shapes of the `memref` based on the defining op. If the
4749// / complete dynamic shape fails to be captured, return an empty value.
48- // / Currently, only function parameters are supported for capturing.
50+ // / Currently, only function block arguments are supported for capturing.
4951static SmallVector<Value> getDynamicSize (Value memref, func::FuncOp funcOp) {
50- auto *defOp = memref.getDefiningOp ();
52+ Operation *defOp = memref.getDefiningOp ();
5153 if (!defOp)
5254 return {};
5355 auto operands = defOp->getOperands ();
5456 SmallVector<Value> dynamicSizes;
5557 for (Value size : operands) {
56- BlockArgument sizeSrc = mlir:: dyn_cast<BlockArgument>(size);
58+ BlockArgument sizeSrc = dyn_cast<BlockArgument>(size);
5759 if (!sizeSrc)
5860 return {};
5961
60- bool finded = false ;
61- for (BlockArgument argument : funcOp.getArguments ()) {
62- if (argument == sizeSrc) {
63- dynamicSizes.push_back (argument);
64- finded = true ;
65- break ;
66- }
67- }
68- if (!finded)
62+ auto iter = llvm::find (funcOp.getArguments (), sizeSrc);
63+ if (!iter)
6964 return {};
65+ dynamicSizes.push_back (*iter);
7066 }
7167 return dynamicSizes;
7268}
@@ -76,18 +72,20 @@ static SmallVector<Value> getDynamicSize(Value memref, func::FuncOp funcOp) {
7672static SmallVector<Value> mapDynamicSizeAtCaller (func::CallOp call,
7773 func::FuncOp callee,
7874 ValueRange dynamicSizes) {
79- SmallVector<Value> mapedDynamicSizes ;
75+ SmallVector<Value> mappedDynamicSizes ;
8076 for (Value size : dynamicSizes) {
8177 auto callOperands = call.getOperands ();
8278 for (size_t i = 0 , e = callOperands.size (); i < e; ++i) {
8379 Value src = callOperands[i];
8480 BlockArgument dst = callee.getArgument (i);
8581 if (size != dst)
8682 continue ;
87- mapedDynamicSizes .push_back (src);
83+ mappedDynamicSizes .push_back (src);
8884 }
8985 }
90- return mapedDynamicSizes;
86+ assert (mappedDynamicSizes.size () == dynamicSizes.size () &&
87+ " could not find all dynamic sizes" );
88+ return mappedDynamicSizes;
9189}
9290
9391// Updates the func op and entry block.
@@ -156,7 +154,8 @@ updateFuncOp(func::FuncOp func,
156154// the given out-params.
157155static LogicalResult
158156updateReturnOps (func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
159- bufferization::BufferResultsToOutParamsOpts &options) {
157+ AllocDynamicSizesMap &map,
158+ const bufferization::BufferResultsToOutParamsOpts &options) {
160159 auto res = func.walk ([&](func::ReturnOp op) {
161160 SmallVector<Value, 6 > copyIntoOutParams;
162161 SmallVector<Value, 6 > keepAsReturnOperands;
@@ -171,10 +170,10 @@ updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
171170 for (auto [orig, arg] : llvm::zip (copyIntoOutParams, appendedEntryArgs)) {
172171 bool hoistStaticAllocs =
173172 options.hoistStaticAllocs &&
174- mlir:: cast<MemRefType>(orig.getType ()).hasStaticShape ();
173+ cast<MemRefType>(orig.getType ()).hasStaticShape ();
175174 bool hoistDynamicAllocs =
176175 options.hoistDynamicAllocs &&
177- !mlir:: cast<MemRefType>(orig.getType ()).hasStaticShape ();
176+ !cast<MemRefType>(orig.getType ()).hasStaticShape ();
178177 if ((hoistStaticAllocs || hoistDynamicAllocs) &&
179178 isa_and_nonnull<bufferization::AllocationOpInterface>(
180179 orig.getDefiningOp ())) {
@@ -194,7 +193,7 @@ updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
194193 auto dynamicSizePair =
195194 std::pair<func::FuncOp, SmallVector<SmallVector<Value>>>(func,
196195 dynamicSizes);
197- options. dynamicSizesMap .insert (dynamicSizePair);
196+ map .insert (dynamicSizePair);
198197 return WalkResult::advance ();
199198 });
200199 return failure (res.wasInterrupted ());
@@ -203,7 +202,7 @@ updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
203202// Updates all CallOps in the scope of the given ModuleOp by allocating
204203// temporary buffers for newly introduced out params.
205204static LogicalResult
206- updateCalls (ModuleOp module ,
205+ updateCalls (ModuleOp module , AllocDynamicSizesMap &map,
207206 const bufferization::BufferResultsToOutParamsOpts &options) {
208207 bool didFail = false ;
209208 SymbolTable symtab (module );
@@ -227,8 +226,7 @@ updateCalls(ModuleOp module,
227226 }
228227 SmallVector<Value, 6 > outParams;
229228 OpBuilder builder (op);
230- SmallVector<SmallVector<Value>> dynamicSizes =
231- options.dynamicSizesMap .lookup (callee);
229+ SmallVector<SmallVector<Value>> dynamicSizes = map.lookup (callee);
232230 size_t dynamicSizesIndex = 0 ;
233231 for (Value memref : replaceWithOutParams) {
234232 SmallVector<Value> dynamicSize = dynamicSizes.size () > dynamicSizesIndex
@@ -287,7 +285,11 @@ updateCalls(ModuleOp module,
287285}
288286
289287LogicalResult mlir::bufferization::promoteBufferResultsToOutParams (
290- ModuleOp module , bufferization::BufferResultsToOutParamsOpts &options) {
288+ ModuleOp module ,
289+ const bufferization::BufferResultsToOutParamsOpts &options) {
290+ // / It maps the shape source of the dynamic shape memref returned by each
291+ // / function.
292+ AllocDynamicSizesMap map;
291293 for (auto func : module .getOps <func::FuncOp>()) {
292294 if (!options.filterFn (&func))
293295 continue ;
@@ -297,11 +299,11 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
297299 return failure ();
298300 if (func.isExternal ())
299301 continue ;
300- if (failed (updateReturnOps (func, appendedEntryArgs, options))) {
302+ if (failed (updateReturnOps (func, appendedEntryArgs, map, options))) {
301303 return failure ();
302304 }
303305 }
304- if (failed (updateCalls (module , options)))
306+ if (failed (updateCalls (module , map, options)))
305307 return failure ();
306308 return success ();
307309}
0 commit comments