Skip to content

Commit fa087b4

Browse files
[mlir][scf][bufferize][NFC] Look up buffer using helper function
Look up iter_arg buffers using `lookupBuffer` instead of always creating a new `ToMemrefOp`. Also cast all yielded buffers (if necessary), regardless of whether they are equivalent buffers or new allocations. Note: This should have been part of D123369. Differential Revision: https://reviews.llvm.org/D123383
1 parent 8d5c8d5 commit fa087b4

File tree

1 file changed

+25
-19
lines changed

1 file changed

+25
-19
lines changed

mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,23 @@ struct ForOpInterface
314314
auto bufferizableOp = cast<BufferizableOpInterface>(op);
315315
Block *oldLoopBody = &forOp.getLoopBody().front();
316316

317+
// Helper function for casting MemRef buffers.
318+
auto castBuffer = [&](Value buffer, Type type) {
319+
assert(type.isa<BaseMemRefType>() && "expected BaseMemRefType");
320+
assert(buffer.getType().isa<BaseMemRefType>() &&
321+
"expected BaseMemRefType");
322+
// If the buffer already has the correct type, no cast is needed.
323+
if (buffer.getType() == type)
324+
return buffer;
325+
// TODO: In case `type` has a layout map that is not the fully dynamic
326+
// one, we may not be able to cast the buffer. In that case, the loop
327+
// iter_arg's layout map must be changed (see uses of `castBuffer`).
328+
assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
329+
"scf.for op bufferization: cast incompatible");
330+
return rewriter.create<memref::CastOp>(buffer.getLoc(), type, buffer)
331+
.getResult();
332+
};
333+
317334
// Indices of all iter_args that have tensor type. These are the ones that
318335
// are bufferized.
319336
DenseSet<int64_t> indices;
@@ -382,17 +399,18 @@ struct ForOpInterface
382399
rewriter.setInsertionPoint(yieldOp);
383400
SmallVector<Value> yieldValues =
384401
convert(yieldOp.getResults(), [&](Value val, int64_t index) {
385-
ensureToMemrefOpIsValid(val, initArgs[index].getType());
386-
Value yieldedVal = rewriter.create<bufferization::ToMemrefOp>(
387-
val.getLoc(), initArgs[index].getType(), val);
402+
Type initArgType = initArgs[index].getType();
403+
ensureToMemrefOpIsValid(val, initArgType);
404+
Value yieldedVal =
405+
bufferization::lookupBuffer(rewriter, val, state.getOptions());
388406

389407
if (equivalentYields[index])
390408
// Yielded value is equivalent to the corresponding iter_arg bbArg.
391409
// Yield the value directly. Most IR should be like that. Everything
392410
// else must be resolved with copies and is potentially inefficient.
393411
// By default, such problematic IR would already have been rejected
394412
// during `verifyAnalysis`, unless `allow-return-allocs`.
395-
return yieldedVal;
413+
return castBuffer(yieldedVal, initArgType);
396414

397415
// It is not certain that the yielded value and the iter_arg bbArg
398416
// have the same buffer. Allocate a new buffer and copy. The yielded
@@ -412,21 +430,9 @@ struct ForOpInterface
412430
(void)copyStatus;
413431
assert(succeeded(copyStatus) && "could not create memcpy");
414432

415-
if (yieldedVal.getType() == yieldedAlloc->getType())
416-
return *yieldedAlloc;
417-
418-
// The iter_arg memref type has a layout map. Cast the new buffer to
419-
// the same type.
420-
// TODO: In case the iter_arg has a layout map that is not the fully
421-
// dynamic one, we cannot cast the new buffer. In that case, the
422-
// iter_arg must be changed to the fully dynamic layout map. (And then
423-
// the new buffer can be casted.)
424-
assert(memref::CastOp::areCastCompatible(yieldedAlloc->getType(),
425-
yieldedVal.getType()) &&
426-
"scf.for op bufferization: cast incompatible");
427-
Value casted = rewriter.create<memref::CastOp>(
428-
val.getLoc(), yieldedVal.getType(), *yieldedAlloc);
429-
return casted;
433+
// The iter_arg memref type may have a layout map. Cast the new buffer
434+
// to the same type if needed.
435+
return castBuffer(*yieldedAlloc, initArgType);
430436
});
431437
yieldOp.getResultsMutable().assign(yieldValues);
432438

0 commit comments

Comments
 (0)