Commit bc19531

qedawkins authored and github-actions[bot] committed
Automerge: [mlir][Bufferization] Fix to_buffer(tensor.cast) folder (#150511)
Automerge: [mlir][Bufferization] Fix to_buffer(tensor.cast) folder (#150511)
Previously, this folder ignored the layout and memory space of the `to_buffer` result type and reset them to their defaults. The pattern now retains both fields from the existing memref type while incorporating the static shape information from the `tensor.cast`. The `read_only` attribute, which the pattern also used to drop, is now preserved as well.
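As a sketch of what the fixed folder now produces (types copied from the new regression test below; value names illustrative), IR such as

  %0 = tensor.cast %arg0 : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8>
  %1 = bufferization.to_buffer %0 : tensor<?x?x16x32xi8>
         to memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>

folds to a to_buffer of the cast source followeded by a memref.cast, keeping the strided layout and memory space 1 rather than substituting the identity layout in the default memory space:

  %m = bufferization.to_buffer %arg0 : tensor<4x6x16x32xi8>
         to memref<4x6x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>
  %1 = memref.cast %m : memref<4x6x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>
         to memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>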
2 parents: a070cb1 + fd8f69d

2 files changed: +28 -4 lines changed

mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp

Lines changed: 10 additions & 2 deletions
@@ -805,10 +805,18 @@ struct ToBufferOfCast : public OpRewritePattern<ToBufferOp> {
         tensorCastOperand.getOperand().getType());
     if (!srcTensorType)
       return failure();
+    auto currentOutputMemRefType =
+        dyn_cast<MemRefType>(toBuffer.getResult().getType());
+    if (!currentOutputMemRefType)
+      return failure();
+
     auto memrefType = MemRefType::get(srcTensorType.getShape(),
-                                      srcTensorType.getElementType());
+                                      srcTensorType.getElementType(),
+                                      currentOutputMemRefType.getLayout(),
+                                      currentOutputMemRefType.getMemorySpace());
     Value memref = ToBufferOp::create(rewriter, toBuffer.getLoc(), memrefType,
-                                      tensorCastOperand.getOperand());
+                                      tensorCastOperand.getOperand(),
+                                      toBuffer.getReadOnly());
     rewriter.replaceOpWithNewOp<memref::CastOp>(toBuffer, toBuffer.getType(),
                                                 memref);
     return success();
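For context (this declaration is quoted from memory of MLIR's BuiltinTypes.h, so treat it as approximate): MemRefType::get takes optional layout and memory-space arguments, and the old code omitted them, which is why the result fell back to the identity layout in the default memory space.

  // Approximate signature; omitted trailing arguments default to the
  // identity layout and the default memory space.
  static MemRefType get(ArrayRef<int64_t> shape, Type elementType,
                        MemRefLayoutAttrInterface layout = {},
                        Attribute memorySpace = {});

The new dyn_cast<MemRefType> guard presumably exists because to_buffer can also produce an unranked memref result, which has no layout to propagate, so the pattern simply bails out in that case.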

mlir/test/Dialect/Bufferization/canonicalize.mlir

Lines changed: 18 additions & 2 deletions
@@ -255,16 +255,32 @@ func.func @clone_and_preceding_dealloc(%arg0: memref<?xf32>) -> memref<32xf32> {
 func.func @tensor_cast_to_buffer(%arg0 : tensor<4x6x16x32xi8>) ->
   memref<?x?x16x32xi8> {
   %0 = tensor.cast %arg0 : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8>
-  %1 = bufferization.to_buffer %0 : tensor<?x?x16x32xi8> to memref<?x?x16x32xi8>
+  %1 = bufferization.to_buffer %0 read_only : tensor<?x?x16x32xi8> to memref<?x?x16x32xi8>
   return %1 : memref<?x?x16x32xi8>
 }
-// CHECK: %[[M:.+]] = bufferization.to_buffer %[[ARG0]] : tensor<4x6x16x32xi8>
+// CHECK: %[[M:.+]] = bufferization.to_buffer %[[ARG0]] read_only : tensor<4x6x16x32xi8>
 // CHECK: %[[M1:.+]] = memref.cast %[[M]]
 // CHECK-SAME: memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
 // CHECK: return %[[M1]] : memref<?x?x16x32xi8>
 
 // -----
 
+// CHECK-LABEL: func @tensor_cast_to_buffer
+// CHECK-SAME: %[[ARG0:.+]]: tensor<4x6x16x32xi8>
+func.func @tensor_cast_to_buffer_layout_and_memspace(%arg0 : tensor<4x6x16x32xi8>) ->
+  memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1> {
+  %0 = tensor.cast %arg0 : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8>
+  %1 = bufferization.to_buffer %0 : tensor<?x?x16x32xi8> to memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>
+  return %1 : memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>
+}
+// CHECK: %[[M:.+]] = bufferization.to_buffer %[[ARG0]] : tensor<4x6x16x32xi8>
+// CHECK: %[[M1:.+]] = memref.cast %[[M]]
+// CHECK-SAME: memref<4x6x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>
+// CHECK-SAME: to memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>
+// CHECK: return %[[M1]] : memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>
+
+// -----
+
 // Folding of memref.load(to_buffer(%v, %idxs)) -> tensor.extract(%v, %idx)
 // CHECK-LABEL: func @load_from_buffer_cast(
 func.func @load_from_buffer_cast(%arg0: index, %arg1: index,
