Skip to content

Commit cd1363b

Browse files
authored
[mlir][Bufferization] Support cast from ranked to unranked in canonic… (#152257)
#150511 changed the canonicalization pattern to not allow casts from ranked to unranked anymore. This patch restores this functionality, while still keeping the fix to preserve memory space and layout.
1 parent a3c386d commit cd1363b

File tree

2 files changed

+16
-5
lines changed

2 files changed

+16
-5
lines changed

mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -806,14 +806,12 @@ struct ToBufferOfCast : public OpRewritePattern<ToBufferOp> {
806806
if (!srcTensorType)
807807
return failure();
808808
auto currentOutputMemRefType =
809-
dyn_cast<MemRefType>(toBuffer.getResult().getType());
809+
dyn_cast<BaseMemRefType>(toBuffer.getResult().getType());
810810
if (!currentOutputMemRefType)
811811
return failure();
812812

813-
auto memrefType = MemRefType::get(srcTensorType.getShape(),
814-
srcTensorType.getElementType(),
815-
currentOutputMemRefType.getLayout(),
816-
currentOutputMemRefType.getMemorySpace());
813+
auto memrefType = currentOutputMemRefType.cloneWith(
814+
srcTensorType.getShape(), srcTensorType.getElementType());
817815
Value memref = ToBufferOp::create(rewriter, toBuffer.getLoc(), memrefType,
818816
tensorCastOperand.getOperand(),
819817
toBuffer.getReadOnly());

mlir/test/Dialect/Bufferization/canonicalize.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,19 @@ func.func @tensor_cast_to_buffer(%arg0 : tensor<4x6x16x32xi8>) ->
263263
// CHECK-SAME: memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
264264
// CHECK: return %[[M1]] : memref<?x?x16x32xi8>
265265

266+
// CHECK-LABEL: func @tensor_cast_to_unranked_buffer
267+
// CHECK-SAME: %[[ARG0:.+]]: tensor<4x6x16x32xi8>
268+
func.func @tensor_cast_to_unranked_buffer(%arg0 : tensor<4x6x16x32xi8>) ->
269+
memref<*xi8> {
270+
%0 = tensor.cast %arg0 : tensor<4x6x16x32xi8> to tensor<*xi8>
271+
%1 = bufferization.to_buffer %0 read_only : tensor<*xi8> to memref<*xi8>
272+
return %1 : memref<*xi8>
273+
}
274+
// CHECK: %[[M:.+]] = bufferization.to_buffer %[[ARG0]] read_only : tensor<4x6x16x32xi8>
275+
// CHECK: %[[M1:.+]] = memref.cast %[[M]]
276+
// CHECK-SAME: memref<4x6x16x32xi8> to memref<*xi8>
277+
// CHECK: return %[[M1]] : memref<*xi8>
278+
266279
// -----
267280

268281
// CHECK-LABEL: func @tensor_cast_to_buffer

0 commit comments

Comments
 (0)