Skip to content

Commit 7c06f63

Browse files
[mlir][tensor][bufferize] Fix dealloc placement in scf.forall op
The terminator of this op is special: it does not just yield a value, but bufferizes to a memcpy. This requires special treatment to make sure that deallocs are placed after the memcpy. (By default, deallocs are placed right before the terminator.) Differential Revision: https://reviews.llvm.org/D148408
1 parent 4889214 commit 7c06f63

File tree

2 files changed

+44
-1
lines changed

2 files changed

+44
-1
lines changed

mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1058,6 +1058,21 @@ struct ParallelInsertSliceOpInterface
10581058
*srcBuffer, subview)))
10591059
return failure();
10601060

1061+
// In case the source was allocated in the same block, make sure that the
1062+
// deallocation op (if any) appears after the memcpy. By default, deallocs
1063+
// are placed before the terminator, but this does not work for ForallOp
1064+
// because the terminator does more than just yielding a value.
1065+
//
1066+
// Note: This is not a problem for destination buffers because these are
1067+
// assumed to always bufferize in-place.
1068+
for (Operation *user : srcBuffer->getUsers()) {
1069+
if (hasEffect<MemoryEffects::Free>(user)) {
1070+
if (user->getBlock() == parallelCombiningParent->getBlock())
1071+
user->moveBefore(user->getBlock()->getTerminator());
1072+
break;
1073+
}
1074+
}
1075+
10611076
// Delete the op.
10621077
rewriter.eraseOp(op);
10631078
return success();

mlir/test/Dialect/Tensor/one-shot-bufferize.mlir

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,7 @@ func.func @insert_slice_full_overwrite(%t: tensor<10xf32>, %b: tensor<10xf32>) -
335335

336336
// CHECK-LABEL: func @dim_not_reading(
337337
// CHECK-SAME: %[[t:.*]]: memref<?xf32
338-
func.func @dim_not_reading(%t: tensor<?xf32>, %f: f32, %pos: index)
338+
func.func @dim_not_reading(%t: tensor<?xf32>, %f: f32, %pos: index)
339339
-> (tensor<?xf32>, index)
340340
{
341341
%c0 = arith.constant 0 : index
@@ -370,3 +370,31 @@ func.func @cast_retains_buffer_layout(
370370
// in the caller.
371371
return %casted, %slice : tensor<10xf32>, tensor<?xf32>
372372
}
373+
374+
// -----
375+
376+
// CHECK-LABEL: func.func @parallel_insert_slice_source_out_of_place
377+
// Regression test for dealloc placement inside scf.forall: the source of the
// tensor.parallel_insert_slice is a buffer allocated inside the loop body, and
// the expected output lists the memref.copy before the memref.dealloc.
func.func @parallel_insert_slice_source_out_of_place(%in: tensor<1xf32>, %out: tensor<100xf32>, %f: f32) {
378+
%c0 = arith.constant 0 : index
379+
%c1 = arith.constant 1 : index
380+
%num_threads = arith.constant 50 : index
381+
382+
// CHECK: scf.forall {{.*}} {
383+
%result = scf.forall (%thread_idx) in (%num_threads) shared_outs (%o = %out) -> tensor<100xf32> {
384+
// The tensor.insert must bufferize out-of-place.
// NOTE(review): presumably because %in is read again by the tensor.extract
// below, so the insert cannot overwrite it in place -- confirm.
385+
// CHECK: memref.alloc
386+
// CHECK: memref.store
387+
%insert = tensor.insert %f into %in[%c0] : tensor<1xf32>
388+
%r = tensor.extract %in[%c0] : tensor<1xf32>
389+
vector.print %r : f32
390+
391+
// The copy of the inserted slice into the shared output must come before
// the dealloc of the temporary source buffer.
// CHECK: memref.copy
392+
// CHECK: memref.dealloc
393+
scf.forall.in_parallel {
394+
tensor.parallel_insert_slice %insert into %o[%thread_idx][1][1] :
395+
tensor<1xf32> into tensor<100xf32>
396+
}
397+
}
398+
// CHECK: }
399+
return
400+
}

0 commit comments

Comments
 (0)