Skip to content

Commit ece4805

Browse files
[mlir][bufferization] Convert tensor encoding -> memref layout
Support custom types (4/N): allow user-specified bufferization of tensor encoding into memref layout. Both tensor encoding and memref layout could be user-specified attributes to store arbitrary information. It is often the case that this information has to be preserved during tensor -> memref bufferization. Thus, provide an optional function to create the memref layout during TensorType::getBufferType() execution. As a drive-by, update AllocTensorOp::getBufferType() to work via TensorLikeType::getBufferType() when the memref layout is user-specified.
1 parent 5e07093 commit ece4805

File tree

5 files changed

+197
-2
lines changed

5 files changed

+197
-2
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,9 @@ struct BufferizationOptions {
272272
// Produce a MemorySpace attribute from a tensor type
273273
using DefaultMemorySpaceFn =
274274
std::function<std::optional<Attribute>(TensorType t)>;
275+
/// Construct a MemRefLayoutAttrInterface from a tensor type.
276+
using ConstructMemRefLayoutFn =
277+
std::function<MemRefLayoutAttrInterface(TensorType t)>;
275278

276279
BufferizationOptions();
277280

@@ -364,6 +367,12 @@ struct BufferizationOptions {
364367
DefaultMemorySpaceFn defaultMemorySpaceFn =
365368
[](TensorType t) -> std::optional<Attribute> { return Attribute(); };
366369

370+
/// Construction function used to determine the memref layout based on the
371+
/// original tensor type. Can be used to specialize tensor encoding -> memref
372+
/// layout conversion. By default, it is unset, making the layout construction
373+
/// behavior depend on the place where it is used.
374+
ConstructMemRefLayoutFn constructMemRefLayoutFn = nullptr;
375+
367376
/// If set to `true`, the analysis is skipped. A buffer is copied before every
368377
/// write. This flag cannot be used together with `testAnalysisOnly = true`.
369378
bool copyBeforeWrite = false;

mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,12 @@ struct BuiltinTensorExternalModel
6565
auto memSpace = options.defaultMemorySpaceFn(tensorType);
6666
if (!memSpace.has_value())
6767
return emitError() << "could not infer memory space";
68+
MemRefLayoutAttrInterface layout = {};
69+
if (options.constructMemRefLayoutFn)
70+
layout = options.constructMemRefLayoutFn(tensorType);
6871

6972
return cast<BufferLikeType>(
70-
getMemRefType(tensorType, options, /*layout=*/{}, *memSpace));
73+
getMemRefType(tensorType, options, layout, *memSpace));
7174
}
7275

7376
mlir::LogicalResult verifyCompatibleBufferType(

mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,15 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
244244
return getOperation()->emitError("could not infer memory space");
245245
}
246246

247+
// Note: Only rely on TensorLikeType::getBufferType() if memref layout is
248+
// explicitly specified by the user. Otherwise, the default behavior is to
249+
// return a fully dynamic layout map which is the opposite of the default
250+
// behavior of this function.
251+
if (options.constructMemRefLayoutFn) {
252+
return cast<TensorLikeType>(getType()).getBufferType(
253+
options, [&]() { return emitError(); });
254+
}
255+
247256
return cast<BufferLikeType>(
248257
getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace));
249258
}
Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
add_mlir_unittest(MLIRTransformsTests
22
Canonicalizer.cpp
33
DialectConversion.cpp
4+
OneShotBufferization.cpp
45
)
56
mlir_target_link_libraries(MLIRTransformsTests
67
PRIVATE
78
MLIRParser
8-
MLIRTransforms)
9+
MLIRTransforms
10+
MLIRBufferizationTransforms
11+
)
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
//===- OneShotBufferization.cpp - One-shot bufferization unit tests -------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
10+
#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
11+
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
12+
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
13+
#include "mlir/Dialect/Func/IR/FuncOps.h"
14+
#include "mlir/IR/BuiltinDialect.h"
15+
#include "mlir/IR/TensorEncoding.h"
16+
#include "mlir/Parser/Parser.h"
17+
#include "mlir/Pass/PassManager.h"
18+
19+
#include "gtest/gtest.h"
20+
21+
using namespace mlir;
22+
23+
namespace {
24+
25+
struct TestTensorAttr : public StringAttr {
26+
using mlir::StringAttr::StringAttr;
27+
28+
static bool classof(mlir::Attribute attr) {
29+
return mlir::isa<mlir::StringAttr>(attr);
30+
}
31+
32+
static TestTensorAttr fromStringAttr(StringAttr attr) {
33+
return mlir::dyn_cast<TestTensorAttr>(attr);
34+
}
35+
};
36+
37+
class TestTensorEncodingVerifier final
38+
: public mlir::VerifiableTensorEncoding::ExternalModel<
39+
TestTensorEncodingVerifier, TestTensorAttr> {
40+
public:
41+
using ConcreteEntity = mlir::StringAttr;
42+
43+
mlir::LogicalResult verifyEncoding(
44+
mlir::Attribute attr, mlir::ArrayRef<int64_t> shape, mlir::Type,
45+
mlir::function_ref<mlir::InFlightDiagnostic()> emitError) const {
46+
std::ignore = shape;
47+
48+
if (mlir::isa<TestTensorAttr>(attr)) {
49+
return mlir::success();
50+
}
51+
return emitError() << "Unknown Tensor enconding: " << attr;
52+
}
53+
};
54+
55+
struct TestMemRefAttr : public mlir::StringAttr {
56+
using mlir::StringAttr::StringAttr;
57+
58+
static bool classof(mlir::Attribute attr) {
59+
return mlir::isa<mlir::StringAttr>(attr);
60+
}
61+
62+
mlir::AffineMap getAffineMap() const {
63+
return mlir::AffineMap::getMultiDimIdentityMap(1, getContext());
64+
}
65+
};
66+
67+
class TestMemRefAttrLayout final
68+
: public mlir::MemRefLayoutAttrInterface::ExternalModel<
69+
TestMemRefAttrLayout, TestMemRefAttr> {
70+
public:
71+
using ConcreteEntity = mlir::StringAttr;
72+
73+
bool isIdentity(mlir::Attribute) const { return true; }
74+
mlir::AffineMap getAffineMap(mlir::Attribute attr) const {
75+
return cast<TestMemRefAttr>(attr).getAffineMap();
76+
}
77+
mlir::LogicalResult
78+
verifyLayout(mlir::Attribute attr, mlir::ArrayRef<int64_t> shape,
79+
mlir::function_ref<mlir::InFlightDiagnostic()> emitError) const {
80+
std::ignore = shape;
81+
82+
if (mlir::isa<TestMemRefAttr>(attr)) {
83+
return mlir::success();
84+
}
85+
return emitError() << "Unknown MemRef layout: " << attr;
86+
}
87+
};
88+
89+
/// Verifies that a user-provided `constructMemRefLayoutFn` converts tensor
/// encodings into the corresponding memref layouts during one-shot module
/// bufferization — across function boundaries and for alloc_tensor results.
TEST(OneShotBufferizationTest, BufferizeTensorEncodingIntoMemRefLayout) {
  MLIRContext context;
  context.getOrLoadDialect<BuiltinDialect>();
  context.getOrLoadDialect<func::FuncDialect>();
  context.getOrLoadDialect<bufferization::BufferizationDialect>();

  // Attach the test encoding/layout external models to the builtin dialect and
  // register the func-dialect bufferization models.
  DialectRegistry registry;
  registry.addExtension(+[](mlir::MLIRContext *ctx, BuiltinDialect *) {
    TestTensorAttr::attachInterface<TestTensorEncodingVerifier>(*ctx);
    TestMemRefAttr::attachInterface<TestMemRefAttrLayout>(*ctx);
  });
  bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
      registry);
  context.appendDialectRegistry(registry);

  const char *const code = R"mlir(
    func.func @foo(%t: tensor<42xf32, "hello">)
        -> tensor<42xf32, "hello"> {
      return %t : tensor<42xf32, "hello">
    }

    func.func @bar(%t1: tensor<42xf32, "hello">)
        -> (tensor<42xf32, "hello">, tensor<12xf32, "not hello">) {
      %out1 = func.call @foo(%t1) : (tensor<42xf32, "hello">)
          -> tensor<42xf32, "hello">

      %out2 = bufferization.alloc_tensor() : tensor<12xf32, "not hello">

      return %out1, %out2 : tensor<42xf32, "hello">, tensor<12xf32, "not hello">
    }
  )mlir";

  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(code, &context);
  ASSERT_NE(module.get(), nullptr) << "parsing should be successful";

  bufferization::OneShotBufferizationOptions options{};
  options.bufferizeFunctionBoundaries = true;
  // Translate a TestTensorAttr encoding into a TestMemRefAttr layout carrying
  // the same string; tensors without such an encoding get no layout.
  options.constructMemRefLayoutFn =
      [](TensorType tensor) -> MemRefLayoutAttrInterface {
    assert(isa<RankedTensorType>(tensor) && "tests only builtin tensors");
    auto rankedType = cast<RankedTensorType>(tensor);
    if (auto encoding = dyn_cast<TestTensorAttr>(rankedType.getEncoding()))
      return cast<MemRefLayoutAttrInterface>(
          TestMemRefAttr::get(tensor.getContext(), encoding.strref()));
    return {};
  };
  // Function boundaries reuse the same encoding -> layout translation.
  options.functionArgTypeConverterFn =
      [&](bufferization::TensorLikeType tensor, Attribute memSpace,
          func::FuncOp, const bufferization::BufferizationOptions &) {
        assert(isa<RankedTensorType>(tensor) && "tests only builtin tensors");
        auto rankedType = cast<RankedTensorType>(tensor);
        auto layout = options.constructMemRefLayoutFn(rankedType);
        return cast<bufferization::BufferLikeType>(
            MemRefType::get(rankedType.getShape(), rankedType.getElementType(),
                            layout, memSpace));
      };

  bufferization::BufferizationState state;
  ASSERT_TRUE(succeeded(bufferization::runOneShotModuleBufferize(
      module->getOperation(), options, state)));

  // True iff `type` is a memref whose layout is a TestMemRefAttr equal to
  // `expectedLayoutValue`.
  const auto checkType = [](Type type, StringRef expectedLayoutValue) {
    auto memref = dyn_cast<MemRefType>(type);
    if (!memref)
      return false;
    auto layout = memref.getLayout();
    if (!isa_and_nonnull<TestMemRefAttr>(layout))
      return false;
    return cast<TestMemRefAttr>(layout) == expectedLayoutValue;
  };

  auto funcOps = module->getOps<func::FuncOp>();
  auto fooOp = *funcOps.begin();
  ASSERT_TRUE(checkType(fooOp.getArgumentTypes()[0], "hello"));
  ASSERT_TRUE(checkType(fooOp.getResultTypes()[0], "hello"));

  auto barOp = *std::next(funcOps.begin());
  ASSERT_TRUE(checkType(barOp.getArgumentTypes()[0], "hello"));
  ASSERT_TRUE(checkType(barOp.getResultTypes()[0], "hello"));
  ASSERT_TRUE(checkType(barOp.getResultTypes()[1], "not hello"));
}
170+
171+
} // end anonymous namespace

0 commit comments

Comments
 (0)