@@ -314,6 +314,23 @@ struct ForOpInterface
     auto bufferizableOp = cast<BufferizableOpInterface>(op);
     Block *oldLoopBody = &forOp.getLoopBody().front();
 
+    // Helper function for casting MemRef buffers.
+    auto castBuffer = [&](Value buffer, Type type) {
+      assert(type.isa<BaseMemRefType>() && "expected BaseMemRefType");
+      assert(buffer.getType().isa<BaseMemRefType>() &&
+             "expected BaseMemRefType");
+      // If the buffer already has the correct type, no cast is needed.
+      if (buffer.getType() == type)
+        return buffer;
+      // TODO: In case `type` has a layout map that is not the fully dynamic
+      // one, we may not be able to cast the buffer. In that case, the loop
+      // iter_arg's layout map must be changed (see uses of `castBuffer`).
+      assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
+             "scf.for op bufferization: cast incompatible");
+      return rewriter.create<memref::CastOp>(buffer.getLoc(), type, buffer)
+          .getResult();
+    };
+
     // Indices of all iter_args that have tensor type. These are the ones that
     // are bufferized.
     DenseSet<int64_t> indices;
@@ -382,17 +399,18 @@ struct ForOpInterface
     rewriter.setInsertionPoint(yieldOp);
     SmallVector<Value> yieldValues =
         convert(yieldOp.getResults(), [&](Value val, int64_t index) {
-          ensureToMemrefOpIsValid(val, initArgs[index].getType());
-          Value yieldedVal = rewriter.create<bufferization::ToMemrefOp>(
-              val.getLoc(), initArgs[index].getType(), val);
+          Type initArgType = initArgs[index].getType();
+          ensureToMemrefOpIsValid(val, initArgType);
+          Value yieldedVal =
+              bufferization::lookupBuffer(rewriter, val, state.getOptions());
 
           if (equivalentYields[index])
             // Yielded value is equivalent to the corresponding iter_arg bbArg.
             // Yield the value directly. Most IR should be like that. Everything
             // else must be resolved with copies and is potentially inefficient.
             // By default, such problematic IR would already have been rejected
             // during `verifyAnalysis`, unless `allow-return-allocs`.
-            return yieldedVal;
+            return castBuffer(yieldedVal, initArgType);
 
           // It is not certain that the yielded value and the iter_arg bbArg
           // have the same buffer. Allocate a new buffer and copy. The yielded
@@ -412,21 +430,9 @@ struct ForOpInterface
           (void)copyStatus;
           assert(succeeded(copyStatus) && "could not create memcpy");
 
-          if (yieldedVal.getType() == yieldedAlloc->getType())
-            return *yieldedAlloc;
-
-          // The iter_arg memref type has a layout map. Cast the new buffer to
-          // the same type.
-          // TODO: In case the iter_arg has a layout map that is not the fully
-          // dynamic one, we cannot cast the new buffer. In that case, the
-          // iter_arg must be changed to the fully dynamic layout map. (And then
-          // the new buffer can be casted.)
-          assert(memref::CastOp::areCastCompatible(yieldedAlloc->getType(),
-                                                   yieldedVal.getType()) &&
-                 "scf.for op bufferization: cast incompatible");
-          Value casted = rewriter.create<memref::CastOp>(
-              val.getLoc(), yieldedVal.getType(), *yieldedAlloc);
-          return casted;
+          // The iter_arg memref type may have a layout map. Cast the new buffer
+          // to the same type if needed.
+          return castBuffer(*yieldedAlloc, initArgType);
         });
     yieldOp.getResultsMutable().assign(yieldValues);
 
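The `castBuffer` helper introduced above reconciles a yielded buffer's type with the iter_arg's memref type when the two differ only in their layout map. As a rough illustration of the kind of IR this produces (the function name, shapes, and layout map below are assumptions for the sketch, not taken from this commit), a bufferized loop on the non-equivalent-yield path might look like:

#layout = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>

func @loop(%init: memref<5xf32, #layout>, %lb: index, %ub: index,
           %step: index) -> memref<5xf32, #layout> {
  %res = scf.for %iv = %lb to %ub step %step
      iter_args(%arg = %init) -> (memref<5xf32, #layout>) {
    // Non-equivalent yield: a fresh buffer with the default identity layout.
    %buf = memref.alloc() : memref<5xf32>
    memref.copy %arg, %buf : memref<5xf32, #layout> to memref<5xf32>
    // The buffer type and the iter_arg type differ only in their layout map,
    // so they are cast compatible; castBuffer inserts the memref.cast.
    %cast = memref.cast %buf : memref<5xf32> to memref<5xf32, #layout>
    scf.yield %cast : memref<5xf32, #layout>
  }
  return %res : memref<5xf32, #layout>
}

When the types already match, the helper returns the buffer unchanged, so the cast only appears where the layouts actually diverge.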