Skip to content

Conversation

@krzysz00
Copy link
Contributor

@krzysz00 krzysz00 commented May 5, 2025

The bufferization.tensor_layout is unnecessarily restricted to affine map attributes when it could reasonably be any implementor of MemRefLayoutAttrInterface.

The bufferization.tensor_layout is unnecessarily restricted to
affine map attributes when it could reasonably be any implementor
of MemRefLayoutAttrInterface.
@krzysz00 krzysz00 requested review from antiagainst and tpopp May 5, 2025 19:31
@llvmbot llvmbot added mlir mlir:tensor mlir:bufferization Bufferization infrastructure labels May 5, 2025
@llvmbot
Copy link
Member

llvmbot commented May 5, 2025

@llvm/pr-subscribers-mlir-bufferization

@llvm/pr-subscribers-mlir

Author: Krzysztof Drewniak (krzysz00)

Changes

The bufferization.tensor_layout is unnecessarily restricted to affine map attributes when it could reasonably be any implementor of MemRefLayoutAttrInterface.


Full diff: https://github.com/llvm/llvm-project/pull/138567.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp (+2-2)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp (+4-4)
  • (modified) mlir/test/Dialect/Tensor/one-shot-bufferize.mlir (+22)
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
index 6b9253a5d71da..d8eac01c2dea0 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
@@ -122,9 +122,9 @@ LogicalResult BufferizationDialect::verifyRegionArgAttribute(
     return success();
   }
   if (attr.getName() == kBufferLayoutAttrName) {
-    if (!llvm::isa<AffineMapAttr>(attr.getValue())) {
+    if (!llvm::isa<MemRefLayoutAttrInterface>(attr.getValue())) {
       return op->emitError() << "'" << kBufferLayoutAttrName
-                             << "' is expected to be a affine map attribute";
+                             << "' is expected to be a memref layout attribute";
     }
     if (!isa<FunctionOpInterface>(op))
       return op->emitError() << "expected '" << kBufferLayoutAttrName
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index c45678f1e4b4d..0b0dcc9162a9a 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -63,16 +63,16 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
   BaseMemRefType memrefType = options.functionArgTypeConverterFn(
       tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
 
-  auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
+  auto layoutAttr = funcOp.getArgAttrOfType<MemRefLayoutAttrInterface>(
       index, BufferizationDialect::kBufferLayoutAttrName);
   if (!layoutAttr)
     return memrefType;
 
   auto rankedMemrefType = dyn_cast<MemRefType>(memrefType);
   assert(rankedMemrefType && "buffer layout not supported on unranked tensors");
-  return MemRefType::get(
-      rankedMemrefType.getShape(), rankedMemrefType.getElementType(),
-      layoutAttr.getValue(), rankedMemrefType.getMemorySpace());
+  return MemRefType::get(rankedMemrefType.getShape(),
+                         rankedMemrefType.getElementType(), layoutAttr,
+                         rankedMemrefType.getMemorySpace());
 }
 
 /// Return the FuncOp called by `callOp`.
diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
index 2983cd30258a5..5f95da25cbc74 100644
--- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
@@ -353,6 +353,28 @@ func.func @cast_retains_buffer_layout(
 
 // -----
 
+// CHECK-LABEL: func.func @cast_retains_buffer_layout_strided(
+//  CHECK-SAME:     %[[t:.*]]: memref<?xf32, strided<[1], offset: 5>>, %[[sz:.*]]: index) -> memref<?xf32, strided<[1], offset: 7>> {
+//       CHECK:   %[[casted:.*]] = memref.cast %[[t]] : memref<?xf32, strided<[1], offset: 5>> to memref<10xf32, strided<[1], offset: 5>>
+//       CHECK:   %[[slice:.*]] = memref.subview %[[casted]][2] [%[[sz]]] [1] : memref<10xf32, strided<[1], offset: 5>> to memref<?xf32, strided<[1], offset: 7>>
+//       CHECK:   return %[[slice]]
+func.func @cast_retains_buffer_layout_strided(
+    %t: tensor<?xf32>
+        {bufferization.buffer_layout = strided<[1], offset: 5>},
+    %sz: index)
+  -> (tensor<10xf32>, tensor<?xf32>)
+{
+  %casted = tensor.cast %t : tensor<?xf32> to tensor<10xf32>
+  %slice = tensor.extract_slice %casted[2][%sz][1] : tensor<10xf32> to tensor<?xf32>
+
+  // Note: The %casted return type is folded away because both buffers are
+  // equivalent. Therefore, we currently loose some static type information
+  // in the caller.
+  return %casted, %slice : tensor<10xf32>, tensor<?xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @parallel_insert_slice_source_out_of_place
 func.func @parallel_insert_slice_source_out_of_place(%in: tensor<1xf32>, %out: tensor<100xf32>, %f: f32) {
   %c0 = arith.constant 0 : index

@krzysz00 krzysz00 merged commit 6fc092f into llvm:main May 6, 2025
15 checks passed
GeorgeARM pushed a commit to GeorgeARM/llvm-project that referenced this pull request May 7, 2025
…ttr (llvm#138567)

The bufferization.tensor_layout is unnecessarily restricted to affine
map attributes when it could reasonably be any implementor of
MemRefLayoutAttrInterface.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

mlir:bufferization Bufferization infrastructure mlir:tensor mlir

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants