From a492cf9a002bbd3116d513930c3917f23f42e3d7 Mon Sep 17 00:00:00 2001 From: Krzysztof Drewniak Date: Mon, 5 May 2025 19:25:17 +0000 Subject: [PATCH] [mlir][bufferization] Let bufferization.tensor_layout be any layout attr The bufferization.tensor_layout is unnecessarily restricted to affine map attributes when it could reasonably be any implementor of MemRefLayoutAttrInterface. --- .../Bufferization/IR/BufferizationDialect.cpp | 4 ++-- .../FuncBufferizableOpInterfaceImpl.cpp | 8 +++---- .../Dialect/Tensor/one-shot-bufferize.mlir | 22 +++++++++++++++++++ 3 files changed, 28 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp index 6b9253a5d71da..d8eac01c2dea0 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp @@ -122,9 +122,9 @@ LogicalResult BufferizationDialect::verifyRegionArgAttribute( return success(); } if (attr.getName() == kBufferLayoutAttrName) { - if (!llvm::isa(attr.getValue())) { + if (!llvm::isa(attr.getValue())) { return op->emitError() << "'" << kBufferLayoutAttrName - << "' is expected to be a affine map attribute"; + << "' is expected to be a memref layout attribute"; } if (!isa(op)) return op->emitError() << "expected '" << kBufferLayoutAttrName diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp index c45678f1e4b4d..0b0dcc9162a9a 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp @@ -63,16 +63,16 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index, BaseMemRefType memrefType = options.functionArgTypeConverterFn( tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options); - auto layoutAttr = funcOp.getArgAttrOfType( + auto layoutAttr = funcOp.getArgAttrOfType( index, BufferizationDialect::kBufferLayoutAttrName); if (!layoutAttr) return memrefType; auto rankedMemrefType = dyn_cast(memrefType); assert(rankedMemrefType && "buffer layout not supported on unranked tensors"); - return MemRefType::get( - rankedMemrefType.getShape(), rankedMemrefType.getElementType(), - layoutAttr.getValue(), rankedMemrefType.getMemorySpace()); + return MemRefType::get(rankedMemrefType.getShape(), + rankedMemrefType.getElementType(), layoutAttr, + rankedMemrefType.getMemorySpace()); } /// Return the FuncOp called by `callOp`. diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir index 2983cd30258a5..5f95da25cbc74 100644 --- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir +++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir @@ -353,6 +353,28 @@ func.func @cast_retains_buffer_layout( // ----- +// CHECK-LABEL: func.func @cast_retains_buffer_layout_strided( +// CHECK-SAME: %[[t:.*]]: memref>, %[[sz:.*]]: index) -> memref> { +// CHECK: %[[casted:.*]] = memref.cast %[[t]] : memref> to memref<10xf32, strided<[1], offset: 5>> +// CHECK: %[[slice:.*]] = memref.subview %[[casted]][2] [%[[sz]]] [1] : memref<10xf32, strided<[1], offset: 5>> to memref> +// CHECK: return %[[slice]] +func.func @cast_retains_buffer_layout_strided( + %t: tensor + {bufferization.buffer_layout = strided<[1], offset: 5>}, + %sz: index) + -> (tensor<10xf32>, tensor) +{ + %casted = tensor.cast %t : tensor to tensor<10xf32> + %slice = tensor.extract_slice %casted[2][%sz][1] : tensor<10xf32> to tensor + + // Note: The %casted return type is folded away because both buffers are + // equivalent. Therefore, we currently loose some static type information + // in the caller. + return %casted, %slice : tensor<10xf32>, tensor +} + +// ----- + // CHECK-LABEL: func.func @parallel_insert_slice_source_out_of_place func.func @parallel_insert_slice_source_out_of_place(%in: tensor<1xf32>, %out: tensor<100xf32>, %f: f32) { %c0 = arith.constant 0 : index