Skip to content

Commit 48867b0

Browse files
[mlir][tensor] Fix bufferization interface for 'tensor.reshape'
Previously, the BufferizableOpInterface implementation for 'tensor.reshape' listed the 'shape' operand as an alias for the result tensor, causing unnecessary conflicts with ops that "write" to the shape operand.
1 parent e5ce030 commit 48867b0

File tree

2 files changed

+31
-0
lines changed

2 files changed

+31
-0
lines changed

mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -862,6 +862,10 @@ struct ReshapeOpInterface
862862

863863
AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
864864
const AnalysisState &state) const {
865+
// Only the 'source' operand aliases the result.
866+
auto reshapeOp = cast<tensor::ReshapeOp>(op);
867+
if (reshapeOp.getSourceMutable() != opOperand)
868+
return {};
865869
return {{op->getOpResult(0), BufferRelation::Equivalent}};
866870
}
867871

mlir/test/Dialect/Tensor/one-shot-bufferize.mlir

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,33 @@ func.func @tensor.reshape() -> tensor<2x2x5xf32> {
398398

399399
// -----
400400

401+
// CHECK-LABEL: func @tensor_reshape_aliasing
402+
// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: index)
403+
func.func @tensor_reshape_aliasing(%arg0: index, %arg1: index) -> tensor<?x?xf32> {
404+
%t1_static = arith.constant dense<0.> : tensor<10xf32>
405+
// CHECK-DAG: %[[T1:.+]] = memref.cast
406+
%t1 = tensor.cast %t1_static : tensor<10xf32> to tensor<?xf32>
407+
408+
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
409+
%c0 = arith.constant 0 : index
410+
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
411+
%c1 = arith.constant 1 : index
412+
413+
// CHECK-DAG: %[[SHAPE:.+]] = memref.alloc() {{.*}} : memref<2xindex>
414+
%shape = bufferization.alloc_tensor() : tensor<2xindex>
415+
// CHECK: memref.store %[[ARG0]], %[[SHAPE]][%[[C0]]]
416+
%shape.0 = tensor.insert %arg0 into %shape[%c0] : tensor<2xindex>
417+
// CHECK: memref.store %[[ARG1]], %[[SHAPE]][%[[C1]]]
418+
%shape.1 = tensor.insert %arg1 into %shape.0[%c1] : tensor<2xindex>
419+
420+
// CHECK: %[[RESHAPED:.+]] = memref.reshape %[[T1]](%[[SHAPE]])
421+
%reshaped = tensor.reshape %t1(%shape.1) : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
422+
// CHECK: return %[[RESHAPED]]
423+
return %reshaped : tensor<?x?xf32>
424+
}
425+
426+
// -----
427+
401428
// CHECK-LABEL: @reshape_with_non_identity_layout(
402429
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]*]]: memref<2x2xf32, strided<[?, ?], offset: ?>, 3>,
403430
// CHECK-SAME: %[[LAYOUT:[a-zA-Z0-9]*]]: memref<2xi32, strided<[?], offset: ?>>,

0 commit comments

Comments
 (0)