diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp index a9ba662348a52..4ac6eca586961 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -860,6 +860,10 @@ struct ReshapeOpInterface AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { + // Only the 'source' operand aliases the result. + auto reshapeOp = cast(op); + if (reshapeOp.getSourceMutable() != opOperand) + return {}; return {{op->getOpResult(0), BufferRelation::Equivalent}}; } diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir index af4f84640890b..2983cd30258a5 100644 --- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir +++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir @@ -398,6 +398,33 @@ func.func @tensor.reshape() -> tensor<2x2x5xf32> { // ----- +// CHECK-LABEL: func @tensor_reshape_aliasing +// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: index) +func.func @tensor_reshape_aliasing(%arg0: index, %arg1: index) -> tensor { + %t1_static = arith.constant dense<0.> : tensor<10xf32> + // CHECK-DAG: %[[T1:.+]] = memref.cast + %t1 = tensor.cast %t1_static : tensor<10xf32> to tensor + + // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index + %c0 = arith.constant 0 : index + // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index + %c1 = arith.constant 1 : index + + // CHECK-DAG: %[[SHAPE:.+]] = memref.alloc() {{.*}} : memref<2xindex> + %shape = bufferization.alloc_tensor() : tensor<2xindex> + // CHECK: memref.store %[[ARG0]], %[[SHAPE]][%[[C0]]] + %shape.0 = tensor.insert %arg0 into %shape[%c0] : tensor<2xindex> + // CHECK: memref.store %[[ARG1]], %[[SHAPE]][%[[C1]]] + %shape.1 = tensor.insert %arg1 into %shape.0[%c1] : tensor<2xindex> + + // CHECK: %[[RESHAPED:.+]] = memref.reshape %[[T1]](%[[SHAPE]]) + %reshaped = tensor.reshape %t1(%shape.1) : (tensor, tensor<2xindex>) -> tensor + // CHECK: return %[[RESHAPED]] + return %reshaped : tensor +} + +// ----- + // CHECK-LABEL: @reshape_with_non_identity_layout( // CHECK-SAME: %[[INPUT:[a-zA-Z0-9]*]]: memref<2x2xf32, strided<[?, ?], offset: ?>, 3>, // CHECK-SAME: %[[LAYOUT:[a-zA-Z0-9]*]]: memref<2xi32, strided<[?], offset: ?>>,