
Conversation

@andrey-golubev
Contributor

@andrey-golubev andrey-golubev commented Sep 29, 2025

Support custom types (4/N): test that it is possible to customize memref
layout specification for custom operations and function boundaries.

This is purely a test setup (no API modifications) to ensure users are
able to pass information from tensors to memrefs within the bufferization
process. To achieve this, a test pass is required (since the bufferization
options have to be set manually). As there is already a
--test-one-shot-module-bufferize pass present, it is extended for this
purpose.
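
A rough sketch of how such a pass could wire things up (illustrative only; the names and wiring here are assumptions, not the actual pass implementation):

#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;

namespace {
// Hypothetical sketch: a test pass that sets bufferization options manually
// and then runs the one-shot module bufferization driver.
struct TestOneShotModuleBufferizePass
    : public PassWrapper<TestOneShotModuleBufferizePass,
                         OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOneShotModuleBufferizePass)

  void runOnOperation() override {
    bufferization::OneShotBufferizationOptions options;
    options.bufferizeFunctionBoundaries = true;
    // Custom encoding -> layout behavior would be configured here.
    bufferization::BufferizationState state;
    if (failed(bufferization::runOneShotModuleBufferize(getOperation(),
                                                        options, state)))
      signalPassFailure();
  }
};
} // namespace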

@llvmbot llvmbot added mlir mlir:bufferization Bufferization infrastructure labels Sep 29, 2025
@llvmbot
Member

llvmbot commented Sep 29, 2025

@llvm/pr-subscribers-mlir-bufferization

@llvm/pr-subscribers-mlir

Author: Andrei Golubev (andrey-golubev)

Changes

Support custom types (4/N): allow user-specified bufferization of tensor encoding into memref layout.

Both the tensor encoding and the memref layout can be user-specified attributes that store arbitrary information. Often, this information has to be preserved during tensor -> memref bufferization. Thus, provide an options function to create the memref layout during TensorType::getBufferType() execution.
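
A minimal sketch of the intended usage (MyEncodingAttr, MyLayoutAttr, and getFormat() are hypothetical downstream names, not part of this patch):

bufferization::OneShotBufferizationOptions options;
options.constructMemRefLayoutFn =
    [](TensorType t) -> MemRefLayoutAttrInterface {
  auto ranked = dyn_cast<RankedTensorType>(t);
  if (!ranked)
    return {};
  // Translate a custom tensor encoding into a custom memref layout.
  if (auto enc = dyn_cast_or_null<MyEncodingAttr>(ranked.getEncoding()))
    return MyLayoutAttr::get(t.getContext(), enc.getFormat());
  return {}; // Fall back to the default layout construction.
};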

As a drive-by, update AllocTensorOp::getBufferType() to work via TensorLikeType::getBufferType() when the memref layout is user-specified.


Full diff: https://github.com/llvm/llvm-project/pull/161166.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h (+9)
  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp (+4-1)
  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp (+9)
  • (modified) mlir/unittests/Transforms/CMakeLists.txt (+4-1)
  • (added) mlir/unittests/Transforms/OneShotBufferization.cpp (+171)
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index dd693a25fd54f..bc5ebbcc64031 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -272,6 +272,9 @@ struct BufferizationOptions {
   // Produce a MemorySpace attribute from a tensor type
   using DefaultMemorySpaceFn =
       std::function<std::optional<Attribute>(TensorType t)>;
+  /// Construct a MemRefLayoutAttrInterface from a tensor type.
+  using ConstructMemRefLayoutFn =
+      std::function<MemRefLayoutAttrInterface(TensorType t)>;
 
   BufferizationOptions();
 
@@ -364,6 +367,12 @@ struct BufferizationOptions {
   DefaultMemorySpaceFn defaultMemorySpaceFn =
       [](TensorType t) -> std::optional<Attribute> { return Attribute(); };
 
+  /// Construction function used to determine the memref layout based on the
+  /// original tensor type. Can be used to specialize tensor encoding -> memref
+  /// layout conversion. By default, it is unset, making the layout construction
+  /// behavior depend on the place where it is used.
+  ConstructMemRefLayoutFn constructMemRefLayoutFn = nullptr;
+
   /// If set to `true`, the analysis is skipped. A buffer is copied before every
   /// write. This flag cannot be used together with `testAnalysisOnly = true`.
   bool copyBeforeWrite = false;
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
index 6c08cdfb669f3..74864bfd57e58 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
@@ -65,9 +65,12 @@ struct BuiltinTensorExternalModel
     auto memSpace = options.defaultMemorySpaceFn(tensorType);
     if (!memSpace.has_value())
       return emitError() << "could not infer memory space";
+    MemRefLayoutAttrInterface layout = {};
+    if (options.constructMemRefLayoutFn)
+      layout = options.constructMemRefLayoutFn(tensorType);
 
     return cast<BufferLikeType>(
-        getMemRefType(tensorType, options, /*layout=*/{}, *memSpace));
+        getMemRefType(tensorType, options, layout, *memSpace));
   }
 
   mlir::LogicalResult verifyCompatibleBufferType(
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 56ff2121e4620..a55b38aea6297 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -244,6 +244,15 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
     return getOperation()->emitError("could not infer memory space");
   }
 
+  // Note: Only rely on TensorLikeType::getBufferType() if memref layout is
+  // explicitly specified by the user. Otherwise, the default behavior is to
+  // return a fully dynamic layout map which is the opposite of the default
+  // behavior of this function.
+  if (options.constructMemRefLayoutFn) {
+    return cast<TensorLikeType>(getType()).getBufferType(
+        options, [&]() { return emitError(); });
+  }
+
   return cast<BufferLikeType>(
       getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace));
 }
diff --git a/mlir/unittests/Transforms/CMakeLists.txt b/mlir/unittests/Transforms/CMakeLists.txt
index dc5920087b505..cd2548f45c94e 100644
--- a/mlir/unittests/Transforms/CMakeLists.txt
+++ b/mlir/unittests/Transforms/CMakeLists.txt
@@ -1,8 +1,11 @@
 add_mlir_unittest(MLIRTransformsTests
   Canonicalizer.cpp
   DialectConversion.cpp
+  OneShotBufferization.cpp
 )
 mlir_target_link_libraries(MLIRTransformsTests
   PRIVATE
   MLIRParser
-  MLIRTransforms)
+  MLIRTransforms
+  MLIRBufferizationTransforms
+)
diff --git a/mlir/unittests/Transforms/OneShotBufferization.cpp b/mlir/unittests/Transforms/OneShotBufferization.cpp
new file mode 100644
index 0000000000000..a1d888b556c8c
--- /dev/null
+++ b/mlir/unittests/Transforms/OneShotBufferization.cpp
@@ -0,0 +1,171 @@
+//===- OneShotBufferization.cpp - One-shot bufferization unit tests -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
+#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/TensorEncoding.h"
+#include "mlir/Parser/Parser.h"
+#include "mlir/Pass/PassManager.h"
+
+#include "gtest/gtest.h"
+
+using namespace mlir;
+
+namespace {
+
+struct TestTensorAttr : public StringAttr {
+  using mlir::StringAttr::StringAttr;
+
+  static bool classof(mlir::Attribute attr) {
+    return mlir::isa<mlir::StringAttr>(attr);
+  }
+
+  static TestTensorAttr fromStringAttr(StringAttr attr) {
+    return mlir::dyn_cast<TestTensorAttr>(attr);
+  }
+};
+
+class TestTensorEncodingVerifier final
+    : public mlir::VerifiableTensorEncoding::ExternalModel<
+          TestTensorEncodingVerifier, TestTensorAttr> {
+public:
+  using ConcreteEntity = mlir::StringAttr;
+
+  mlir::LogicalResult verifyEncoding(
+      mlir::Attribute attr, mlir::ArrayRef<int64_t> shape, mlir::Type,
+      mlir::function_ref<mlir::InFlightDiagnostic()> emitError) const {
+    std::ignore = shape;
+
+    if (mlir::isa<TestTensorAttr>(attr)) {
+      return mlir::success();
+    }
+    return emitError() << "Unknown Tensor enconding: " << attr;
+  }
+};
+
+struct TestMemRefAttr : public mlir::StringAttr {
+  using mlir::StringAttr::StringAttr;
+
+  static bool classof(mlir::Attribute attr) {
+    return mlir::isa<mlir::StringAttr>(attr);
+  }
+
+  mlir::AffineMap getAffineMap() const {
+    return mlir::AffineMap::getMultiDimIdentityMap(1, getContext());
+  }
+};
+
+class TestMemRefAttrLayout final
+    : public mlir::MemRefLayoutAttrInterface::ExternalModel<
+          TestMemRefAttrLayout, TestMemRefAttr> {
+public:
+  using ConcreteEntity = mlir::StringAttr;
+
+  bool isIdentity(mlir::Attribute) const { return true; }
+  mlir::AffineMap getAffineMap(mlir::Attribute attr) const {
+    return cast<TestMemRefAttr>(attr).getAffineMap();
+  }
+  mlir::LogicalResult
+  verifyLayout(mlir::Attribute attr, mlir::ArrayRef<int64_t> shape,
+               mlir::function_ref<mlir::InFlightDiagnostic()> emitError) const {
+    std::ignore = shape;
+
+    if (mlir::isa<TestMemRefAttr>(attr)) {
+      return mlir::success();
+    }
+    return emitError() << "Unknown MemRef layout: " << attr;
+  }
+};
+
+TEST(OneShotBufferizationTest, BufferizeTensorEncodingIntoMemRefLayout) {
+  MLIRContext context;
+  context.getOrLoadDialect<BuiltinDialect>();
+  context.getOrLoadDialect<func::FuncDialect>();
+  context.getOrLoadDialect<bufferization::BufferizationDialect>();
+
+  DialectRegistry registry;
+  registry.addExtension(+[](mlir::MLIRContext *ctx, BuiltinDialect *) {
+    TestTensorAttr::attachInterface<TestTensorEncodingVerifier>(*ctx);
+    TestMemRefAttr::attachInterface<TestMemRefAttrLayout>(*ctx);
+  });
+  bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
+      registry);
+  context.appendDialectRegistry(registry);
+
+  const char *const code = R"mlir(
+    func.func @foo(%t: tensor<42xf32, "hello">)
+        -> tensor<42xf32, "hello"> {
+      return %t : tensor<42xf32, "hello">
+    }
+
+    func.func @bar(%t1: tensor<42xf32, "hello">)
+        -> (tensor<42xf32, "hello">, tensor<12xf32, "not hello">) {
+      %out1 = func.call @foo(%t1) : (tensor<42xf32, "hello">)
+        -> tensor<42xf32, "hello">
+
+      %out2 = bufferization.alloc_tensor() : tensor<12xf32, "not hello">
+
+      return %out1, %out2 : tensor<42xf32, "hello">, tensor<12xf32, "not hello">
+    }
+  )mlir";
+
+  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(code, &context);
+  ASSERT_NE(module.get(), nullptr) << "parsing should be successful";
+
+  bufferization::OneShotBufferizationOptions options{};
+  options.bufferizeFunctionBoundaries = true;
+  options.constructMemRefLayoutFn =
+      [](TensorType tensor) -> MemRefLayoutAttrInterface {
+    assert(isa<RankedTensorType>(tensor) && "tests only builtin tensors");
+    auto tensorType = cast<RankedTensorType>(tensor);
+    if (auto encoding = dyn_cast<TestTensorAttr>(tensorType.getEncoding())) {
+      return cast<MemRefLayoutAttrInterface>(
+          TestMemRefAttr::get(tensor.getContext(), encoding.strref()));
+    }
+    return {};
+  };
+  options.functionArgTypeConverterFn =
+      [&](bufferization::TensorLikeType tensor, Attribute memSpace,
+          func::FuncOp, const bufferization::BufferizationOptions &) {
+        assert(isa<RankedTensorType>(tensor) && "tests only builtin tensors");
+        auto tensorType = cast<RankedTensorType>(tensor);
+        auto layout = options.constructMemRefLayoutFn(tensorType);
+        return cast<bufferization::BufferLikeType>(
+            MemRefType::get(tensorType.getShape(), tensorType.getElementType(),
+                            layout, memSpace));
+      };
+
+  bufferization::BufferizationState state;
+  ASSERT_TRUE(succeeded(bufferization::runOneShotModuleBufferize(
+      module->getOperation(), options, state)));
+
+  const auto checkType = [](Type type, StringRef expectedLayoutValue) {
+    if (auto memref = dyn_cast<MemRefType>(type)) {
+      if (auto layout = memref.getLayout();
+          isa_and_nonnull<TestMemRefAttr>(layout)) {
+        return cast<TestMemRefAttr>(layout) == expectedLayoutValue;
+      }
+    }
+    return false;
+  };
+
+  auto fooOp = *module->getOps<func::FuncOp>().begin();
+  ASSERT_TRUE(checkType(fooOp.getArgumentTypes()[0], "hello"));
+  ASSERT_TRUE(checkType(fooOp.getResultTypes()[0], "hello"));
+
+  auto barOp = *std::next(module->getOps<func::FuncOp>().begin());
+  ASSERT_TRUE(checkType(barOp.getArgumentTypes()[0], "hello"));
+  ASSERT_TRUE(checkType(barOp.getResultTypes()[0], "hello"));
+  ASSERT_TRUE(checkType(barOp.getResultTypes()[1], "not hello"));
+}
+
+} // end anonymous namespace

@github-actions

github-actions bot commented Sep 29, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@andrey-golubev
Contributor Author

andrey-golubev commented Sep 29, 2025

There are a couple of things that I don't like and perhaps we can sort them out during review:

  • I don't like the test in C++, but it seems somewhat necessary, since I need to patch the dialect registry to use the custom encoding / layout:
    https://github.com/llvm/llvm-project/pull/161166/files#diff-992dd79e0fd3edebb88acda5a05cf9b10aee6186e04ace2cb4b88f0adb196b44R95-R99
    • Can we update the dialect registry within a pass (generally, this is not great, but I don't have a better option thus far)? Perhaps I could have a test bufferization pass instead that does this prior to calling bufferizeOp.
  • Having to patch AllocTensorOp. In general, I think the fact that it doesn't go through TensorLikeType::getBufferType() is an oversight. For the sake of the test, however, I'd rather use my own test operations (such as TestCreateTensorOp and the like), but the Test dialect is not available in unit tests right now (I would need to add a CMake dependency, and perhaps that also defeats the conceptual idea of these tests?)

return cast<TensorLikeType>(getType()).getBufferType(
    options, [&]() { return emitError(); });
}

Contributor Author

@andrey-golubev andrey-golubev Sep 29, 2025

Note: right now, this is pretty much a crutch. I'm thinking that a good design could be this:

  • default constructMemRefLayoutFn returns fully dynamic layout map
  • user can specify their own behavior by setting a different callable
  • questionable: AllocTensorOp::getBufferType() would copy bufferization options and set:
    • options.defaultMemorySpaceFn (see right above the added code for what happens)
    • options.constructMemRefLayoutFn to return nullptr layout

Additionally, I think we should revisit TensorType::getBufferType(): instead of calling unknownTypeConversionFn, we could just use the constructMemRefLayoutFn introduced by this patch, rendering the unknown type conversion unnecessary (that option could be deleted completely) and making ::getBufferType() look something like:

// for RankedTensorType
return MemRefType::get(
  tensor.getShape(), // as before
  tensor.getElementType(), // as before
  options.constructMemRefLayoutFn(tensor), // this is "new" and universal
  options.defaultMemorySpaceFn(tensor) // as before
);

Perhaps then we could get rid of the getMemRefType / getMemRefTypeWithFullyDynamicLayout / getMemRefTypeWithStaticIdentityLayout triplet (at least there would no longer be any "difference", since the layout specification is controlled via the options).
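
For reference, a sketch of what that default could look like if it mirrored getMemRefTypeWithFullyDynamicLayout() (an assumption about the design, not part of this patch):

options.constructMemRefLayoutFn =
    [](TensorType t) -> MemRefLayoutAttrInterface {
  auto ranked = cast<RankedTensorType>(t);
  // Fully dynamic strided layout: dynamic offset, dynamic strides.
  SmallVector<int64_t> strides(ranked.getRank(), ShapedType::kDynamic);
  return StridedLayoutAttr::get(t.getContext(),
                                /*offset=*/ShapedType::kDynamic, strides);
};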

@andrey-golubev andrey-golubev force-pushed the test_encoding_layout_bufferization branch from d1de425 to 94419f2 Compare September 29, 2025 10:44
@andrey-golubev andrey-golubev changed the title [mlir][bufferization] Convert tensor enconding -> memref layout [mlir][bufferization] Convert tensor enconding into memref layout Sep 29, 2025
/// original tensor type. Can be used to specialize tensor encoding -> memref
/// layout conversion. By default, it is unset, making the layout construction
/// behavior depend on the place where it is used.
ConstructMemRefLayoutFn constructMemRefLayoutFn = nullptr;
Member

Where is this option useful? Bufferization processes operations top-to-bottom (defs before users). The result type of a bufferized op can typically be inferred from the operands. For most ops, it is not possible to put a custom layout map on the result (the op would not verify anymore).

E.g.: tensor.extract_slice -> memref.subview: based on the operands and their types, we can infer the result type.

E.g.: bufferization.alloc_tensor -> memref.alloc: the result buffer always has an identity layout map.

E.g.: tensor function argument: the bufferized type can be controlled via FunctionArgTypeConverterFn.
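
For instance, a minimal converter in that spirit (a sketch; the signature follows the usage in the unit test above):

options.functionArgTypeConverterFn =
    [](bufferization::TensorLikeType tensor, Attribute memSpace,
       func::FuncOp, const bufferization::BufferizationOptions &)
    -> bufferization::BufferLikeType {
  auto ranked = cast<RankedTensorType>(tensor);
  // Identity layout here; a custom layout attribute could be plugged in.
  return cast<bufferization::BufferLikeType>(
      MemRefType::get(ranked.getShape(), ranked.getElementType(),
                      MemRefLayoutAttrInterface{}, memSpace));
};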

Contributor Author

Hmm, thanks for questioning this, actually! I originally thought I would need this in places that rely on RankedTensorType::getBufferType() (thinking of the bufferization.to_tensor conversion, for example).

But thus far, realistically, custom encoding -> custom layout only works for our own operations downstream (e.g., we encode "layout" information differently from upstream MLIR; for example, the type specifies whether it's NHWC, NCHW, or some other "format").

E.g.: bufferization.alloc_tensor -> memref.alloc: the result buffer always has an identity layout map.

Not in our case, but I don't think we rely on bufferization.alloc_tensor either - we have our own allocation for our own types and use something from MLIR for memrefs. I guess I just have to check how that case works - which probably means a test.

I think at this point I don't have a solid example of why custom layout conversion would be necessary, so I'll drop this one and try to mimic what we have in our codebase using standard MLIR. Perhaps it would be enough to just use the current infrastructure.

@andrey-golubev andrey-golubev force-pushed the test_encoding_layout_bufferization branch from 94419f2 to eb51e55 Compare October 6, 2025 12:41
@andrey-golubev andrey-golubev changed the title [mlir][bufferization] Convert tensor enconding into memref layout [mlir][bufferization] Test tensor encoding -> memref layout conversion Oct 6, 2025
// CHECK: %[[call:.*]]:2 = call @inner_func(%[[arg0]])
%0, %1 = call @inner_func(%t0) : (tensor<?xf32>) -> (tensor<?xf32>, f32)
// CHECK: return %[[call]]#1, %[[call]]#0 : f32, memref<?xf32,{{.*}}>
// CHECK: return %[[call]]#1, %[[call]]#0 : f32, memref<?xf32{{.*}}>
Contributor Author

@andrey-golubev andrey-golubev Oct 6, 2025

Note: I don't know why, but somehow running the test locally failed for me without this fix. Given that mlir-opt produces memref<?xf32> for me (no layout attribute, so no comma after f32), I'm not sure how this works in main right now :|

let dependentDialects = ["::mlir::DLTIDialect"];
let dependentDialects = [
"::mlir::DLTIDialect",
"::mlir::bufferization::BufferizationDialect"
Contributor Author

Seems required; otherwise, some verifyInvariantsImpl checks fail internally when checking whether a ranked tensor is a tensor-like type.

    bufferization::TensorLikeType tensorLike) {
  auto buffer =
      *tensorLike.getBufferType(options, [&]() { return op->emitError(); });
  if (auto memref = dyn_cast<MemRefType>(buffer)) {
Contributor Author

Note: if we had an options callback that provided customizable layout inference, this branch could be avoided; instead, the one-shot bufferization options could be configured and this whole thing would become just return TensorLike::getBufferType().
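
A sketch of that simplification (convertTensorToBuffer is the test helper under discussion, not a stable API):

FailureOr<bufferization::BufferLikeType>
convertTensorToBuffer(Operation *op,
                      const bufferization::BufferizationOptions &options,
                      bufferization::TensorLikeType tensorLike) {
  // With a layout-inference callback in the options, no memref-specific
  // branch is needed: the type interface performs the conversion.
  return tensorLike.getBufferType(options,
                                  [&]() { return op->emitError(); });
}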

@andrey-golubev
Contributor Author

@matthias-springer I've redone the tests completely (luckily, I don't need C++ unit tests and can do everything in LIT) and also dropped the API update. This is effectively a "new" PR, but I repurposed the old one.

#161166 (comment) shows where I would rely on a memref layout inference function in the bufferization options (configured by a pass). But at this point it seems to be a "convenience", since custom operations provide their own bufferization anyway.

Anyhow, since we already have RankedTensorType::getBufferType() and a function to get the default memory space, maybe it makes sense to provide default layout inference as well.

@andrey-golubev
Contributor Author

@Evanyl I see that you've updated module-level bufferization. As I'm extending your test pass and tests, feel free to take a look and drop a review.

@andrey-golubev
Contributor Author

Gentle ping.

const auto bufferizedOutType = test::TestMemrefType::get(
    getContext(), outType.getShape(), outType.getElementType(), nullptr);
const auto bufferizedOutType =
    convertTensorToBuffer(getOperation(), options, outType);
Member

It's better to call getBufferType here, so that the two functions cannot get out of sync.

Contributor Author

Do you mean Op::getBufferType? Note that this is a different op from the one below (dummy vs. create), where I don't override ::getBufferType(). The "create tensor" one already does what you suggest here.

@andrey-golubev andrey-golubev merged commit 3141bde into llvm:main Oct 15, 2025
9 checks passed
@andrey-golubev andrey-golubev deleted the test_encoding_layout_bufferization branch October 15, 2025 10:47