Skip to content

Commit 3141bde

Browse files
[mlir][bufferization] Test tensor encoding -> memref layout conversion (llvm#161166)
Support custom types (4/N): test that it is possible to customize memref layout specification for custom operations and function boundaries. This is purely a test setup (no API modifications) to ensure users are able to pass information from tensors to memrefs within bufferization process. To achieve this, a test pass is required (since bufferization options have to be set manually). As there is already a --test-one-shot-module-bufferize pass present, it is extended for the purpose.
1 parent bf64316 commit 3141bde

File tree

9 files changed

+150
-15
lines changed

9 files changed

+150
-15
lines changed

mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,46 @@
2424
// CHECK-NOT: copy
2525
// CHECK: %[[call:.*]]:2 = call @inner_func(%[[arg0]])
2626
%0, %1 = call @inner_func(%t0) : (tensor<?xf32>) -> (tensor<?xf32>, f32)
27-
// CHECK: return %[[call]]#1, %[[call]]#0 : f32, memref<?xf32,{{.*}}>
27+
// CHECK: return %[[call]]#1, %[[call]]#0 : f32, memref<?xf32{{.*}}>
2828
return %1, %0 : f32, tensor<?xf32>
2929
}
3030
"test.finish" () : () -> ()
3131
}) : () -> ()
3232

33+
// -----
3334

35+
#enc1 = #test.tensor_encoding<"hello">
36+
#enc2 = #test.tensor_encoding<"not hello">
37+
38+
"test.symbol_scope_isolated"() ({
39+
// CHECK: func @inner_func(
40+
// CHECK-SAME: %[[arg0:.*]]: memref<?xf32, #test.memref_layout<"hello">>)
41+
// CHECK-SAME: -> memref<?xf32, #test.memref_layout<"hello">>
42+
func.func @inner_func(%t: tensor<?xf32, #enc1>)
43+
-> tensor<?xf32, #enc1> {
44+
// CHECK: return %[[arg0]]
45+
return %t : tensor<?xf32, #enc1>
46+
}
47+
48+
// CHECK: func @outer_func(
49+
// CHECK-SAME: %[[arg0:.*]]: memref<?xf32, #test.memref_layout<"hello">>)
50+
// CHECK-SAME: -> (memref<?xf32, #test.memref_layout<"hello">>,
51+
// CHECK-SAME: memref<?xf32, #test.memref_layout<"not hello">>)
52+
func.func @outer_func(%t0: tensor<?xf32, #enc1>)
53+
-> (tensor<?xf32, #enc1>, tensor<?xf32, #enc2>) {
54+
// CHECK: %[[call:.*]] = call @inner_func(%[[arg0]])
55+
%0 = call @inner_func(%t0)
56+
: (tensor<?xf32, #enc1>) -> (tensor<?xf32, #enc1>)
57+
58+
// CHECK: %[[local:.*]] = "test.create_memref_op"() : ()
59+
// CHECK-SAME: -> memref<?xf32, #test.memref_layout<"not hello">>
60+
%local = "test.create_tensor_op"() : () -> tensor<?xf32, #enc2>
61+
// CHECK: %[[dummy:.*]] = "test.dummy_memref_op"(%[[local]])
62+
%1 = "test.dummy_tensor_op"(%local) : (tensor<?xf32, #enc2>)
63+
-> tensor<?xf32, #enc2>
64+
65+
// CHECK: return %[[call]], %[[dummy]]
66+
return %0, %1 : tensor<?xf32, #enc1>, tensor<?xf32, #enc2>
67+
}
68+
"test.finish" () : () -> ()
69+
}) : () -> ()

mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,25 @@
1111
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
1212
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
1313
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
14+
#include "mlir/Dialect/Func/IR/FuncOps.h"
1415
#include "mlir/Pass/Pass.h"
1516

17+
#include "TestAttributes.h" // TestTensorEncodingAttr, TestMemRefLayoutAttr
18+
#include "TestDialect.h"
19+
1620
using namespace mlir;
1721

1822
namespace {
23+
MemRefLayoutAttrInterface
24+
getMemRefLayoutForTensorEncoding(RankedTensorType tensorType) {
25+
if (auto encoding = dyn_cast_if_present<test::TestTensorEncodingAttr>(
26+
tensorType.getEncoding())) {
27+
return cast<MemRefLayoutAttrInterface>(test::TestMemRefLayoutAttr::get(
28+
tensorType.getContext(), encoding.getDummy()));
29+
}
30+
return {};
31+
}
32+
1933
struct TestOneShotModuleBufferizePass
2034
: public PassWrapper<TestOneShotModuleBufferizePass, OperationPass<>> {
2135
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOneShotModuleBufferizePass)
@@ -25,6 +39,7 @@ struct TestOneShotModuleBufferizePass
2539
: PassWrapper(pass) {}
2640

2741
void getDependentDialects(DialectRegistry &registry) const override {
42+
registry.insert<test::TestDialect>();
2843
registry.insert<bufferization::BufferizationDialect>();
2944
}
3045
StringRef getArgument() const final {
@@ -41,6 +56,17 @@ struct TestOneShotModuleBufferizePass
4156
bufferization::OneShotBufferizationOptions opt;
4257

4358
opt.bufferizeFunctionBoundaries = true;
59+
opt.functionArgTypeConverterFn =
60+
[&](bufferization::TensorLikeType tensor, Attribute memSpace,
61+
func::FuncOp, const bufferization::BufferizationOptions &) {
62+
assert(isa<RankedTensorType>(tensor) && "tests only builtin tensors");
63+
auto tensorType = cast<RankedTensorType>(tensor);
64+
auto layout = getMemRefLayoutForTensorEncoding(tensorType);
65+
return cast<bufferization::BufferLikeType>(
66+
MemRefType::get(tensorType.getShape(),
67+
tensorType.getElementType(), layout, memSpace));
68+
};
69+
4470
bufferization::BufferizationState bufferizationState;
4571

4672
if (failed(bufferization::runOneShotModuleBufferize(getOperation(), opt,

mlir/test/lib/Dialect/Test/TestAttrDefs.td

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ include "mlir/IR/AttrTypeBase.td"
2222
include "mlir/IR/BuiltinAttributeInterfaces.td"
2323
include "mlir/IR/EnumAttr.td"
2424
include "mlir/IR/OpAsmInterface.td"
25+
include "mlir/IR/TensorEncoding.td"
2526

2627
// All of the attributes will extend this class.
2728
class Test_Attr<string name, list<Trait> traits = []>
@@ -439,4 +440,20 @@ def TestCustomStorageCtorAttr : Test_Attr<"TestCustomStorageCtorAttr"> {
439440
let hasStorageCustomConstructor = 1;
440441
}
441442

443+
def TestTensorEncodingAttr : Test_Attr<"TestTensorEncoding",
444+
[DeclareAttrInterfaceMethods<VerifiableTensorEncoding>]> {
445+
let mnemonic = "tensor_encoding";
446+
447+
let parameters = (ins "mlir::StringAttr":$dummy);
448+
let assemblyFormat = "`<` $dummy `>`";
449+
}
450+
451+
def TestMemRefLayoutAttr : Test_Attr<"TestMemRefLayout",
452+
[DeclareAttrInterfaceMethods<MemRefLayoutAttrInterface>]> {
453+
let mnemonic = "memref_layout";
454+
455+
let parameters = (ins "mlir::StringAttr":$dummy);
456+
let assemblyFormat = "`<` $dummy `>`";
457+
}
458+
442459
#endif // TEST_ATTRDEFS

mlir/test/lib/Dialect/Test/TestAttributes.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,24 @@ test::detail::TestCustomStorageCtorAttrAttrStorage::construct(
541541
return nullptr;
542542
}
543543

544+
//===----------------------------------------------------------------------===//
545+
// TestTensorEncodingAttr
546+
//===----------------------------------------------------------------------===//
547+
548+
::llvm::LogicalResult TestTensorEncodingAttr::verifyEncoding(
549+
mlir::ArrayRef<int64_t> shape, mlir::Type elementType,
550+
llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) const {
551+
return mlir::success();
552+
}
553+
554+
//===----------------------------------------------------------------------===//
555+
// TestMemRefLayoutAttr
556+
//===----------------------------------------------------------------------===//
557+
558+
mlir::AffineMap TestMemRefLayoutAttr::getAffineMap() const {
559+
return mlir::AffineMap::getMultiDimIdentityMap(1, getContext());
560+
}
561+
544562
//===----------------------------------------------------------------------===//
545563
// TestDialect
546564
//===----------------------------------------------------------------------===//

mlir/test/lib/Dialect/Test/TestAttributes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "mlir/IR/Dialect.h"
2525
#include "mlir/IR/DialectImplementation.h"
2626
#include "mlir/IR/DialectResourceBlobManager.h"
27+
#include "mlir/IR/TensorEncoding.h"
2728

2829
// generated files require above includes to come first
2930
#include "TestAttrInterfaces.h.inc"

mlir/test/lib/Dialect/Test/TestDialect.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "TestInterfaces.h"
1919
#include "TestTypes.h"
2020
#include "mlir/Bytecode/BytecodeImplementation.h"
21+
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
2122
#include "mlir/Dialect/DLTI/DLTI.h"
2223
#include "mlir/Dialect/DLTI/Traits.h"
2324
#include "mlir/Dialect/Func/IR/FuncOps.h"

mlir/test/lib/Dialect/Test/TestDialect.td

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,10 @@ def Test_Dialect : Dialect {
2424
let useDefaultTypePrinterParser = 0;
2525
let useDefaultAttributePrinterParser = 1;
2626
let isExtensible = 1;
27-
let dependentDialects = ["::mlir::DLTIDialect"];
27+
let dependentDialects = [
28+
"::mlir::DLTIDialect",
29+
"::mlir::bufferization::BufferizationDialect"
30+
];
2831
let discardableAttrs = (ins
2932
"mlir::IntegerAttr":$discardable_attr_key,
3033
"SimpleAAttr":$other_discardable_attr_key

mlir/test/lib/Dialect/Test/TestOpDefs.cpp

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1425,6 +1425,39 @@ TestMultiSlotAlloca::handleDestructuringComplete(
14251425
return createNewMultiAllocaWithoutSlot(slot, builder, *this);
14261426
}
14271427

1428+
namespace {
1429+
/// Returns test dialect's memref layout for test dialect's tensor encoding when
1430+
/// applicable.
1431+
MemRefLayoutAttrInterface
1432+
getMemRefLayoutForTensorEncoding(RankedTensorType tensorType) {
1433+
if (auto encoding =
1434+
dyn_cast<test::TestTensorEncodingAttr>(tensorType.getEncoding())) {
1435+
return cast<MemRefLayoutAttrInterface>(test::TestMemRefLayoutAttr::get(
1436+
tensorType.getContext(), encoding.getDummy()));
1437+
}
1438+
return {};
1439+
}
1440+
1441+
/// Auxiliary bufferization function for test and builtin tensors.
1442+
bufferization::BufferLikeType
1443+
convertTensorToBuffer(mlir::Operation *op,
1444+
const bufferization::BufferizationOptions &options,
1445+
bufferization::TensorLikeType tensorLike) {
1446+
auto buffer =
1447+
*tensorLike.getBufferType(options, [&]() { return op->emitError(); });
1448+
if (auto memref = dyn_cast<MemRefType>(buffer)) {
1449+
// Note: For the sake of testing, we want to ensure that encoding -> layout
1450+
// bufferization happens. This is currently achieved manually.
1451+
auto layout =
1452+
getMemRefLayoutForTensorEncoding(cast<RankedTensorType>(tensorLike));
1453+
return cast<bufferization::BufferLikeType>(
1454+
MemRefType::get(memref.getShape(), memref.getElementType(), layout,
1455+
memref.getMemorySpace()));
1456+
}
1457+
return buffer;
1458+
}
1459+
} // namespace
1460+
14281461
::mlir::LogicalResult test::TestDummyTensorOp::bufferize(
14291462
::mlir::RewriterBase &rewriter,
14301463
const ::mlir::bufferization::BufferizationOptions &options,
@@ -1435,8 +1468,8 @@ ::mlir::LogicalResult test::TestDummyTensorOp::bufferize(
14351468
return failure();
14361469

14371470
const auto outType = getOutput().getType();
1438-
const auto bufferizedOutType = test::TestMemrefType::get(
1439-
getContext(), outType.getShape(), outType.getElementType(), nullptr);
1471+
const auto bufferizedOutType =
1472+
convertTensorToBuffer(getOperation(), options, outType);
14401473
// replace op with memref analogy
14411474
auto dummyMemrefOp = test::TestDummyMemrefOp::create(
14421475
rewriter, getLoc(), bufferizedOutType, *buffer);
@@ -1470,13 +1503,12 @@ ::mlir::LogicalResult test::TestCreateTensorOp::bufferize(
14701503

14711504
mlir::FailureOr<mlir::bufferization::BufferLikeType>
14721505
test::TestCreateTensorOp::getBufferType(
1473-
mlir::Value value, const mlir::bufferization::BufferizationOptions &,
1506+
mlir::Value value, const mlir::bufferization::BufferizationOptions &options,
14741507
const mlir::bufferization::BufferizationState &,
14751508
llvm::SmallVector<::mlir::Value> &) {
1476-
const auto type = dyn_cast<test::TestTensorType>(value.getType());
1509+
const auto type = dyn_cast<bufferization::TensorLikeType>(value.getType());
14771510
if (type == nullptr)
14781511
return failure();
14791512

1480-
return cast<mlir::bufferization::BufferLikeType>(test::TestMemrefType::get(
1481-
getContext(), type.getShape(), type.getElementType(), nullptr));
1513+
return convertTensorToBuffer(getOperation(), options, type);
14821514
}

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ include "mlir/Interfaces/MemorySlotInterfaces.td"
3232
include "mlir/Interfaces/SideEffectInterfaces.td"
3333
include "mlir/Interfaces/ValueBoundsOpInterface.td"
3434
include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
35+
include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td"
3536

3637
// Include the attribute definitions.
3738
include "TestAttrDefs.td"
@@ -2335,7 +2336,7 @@ def SideEffectWithRegionOp : TEST_Op<"side_effect_with_region_op",
23352336
}
23362337

23372338
//===----------------------------------------------------------------------===//
2338-
// Copy Operation Test
2339+
// Copy Operation Test
23392340
//===----------------------------------------------------------------------===//
23402341

23412342
def CopyOp : TEST_Op<"copy", []> {
@@ -3676,10 +3677,10 @@ def TestDummyTensorOp : TEST_Op<"dummy_tensor_op",
36763677
["bufferize", "bufferizesToMemoryRead",
36773678
"bufferizesToMemoryWrite", "getAliasingValues"]>]> {
36783679
let arguments = (ins
3679-
Arg<TestTensorType>:$input
3680+
Arg<Bufferization_TensorLikeTypeInterface>:$input
36803681
);
36813682
let results = (outs
3682-
Arg<TestTensorType>:$output
3683+
Arg<Bufferization_TensorLikeTypeInterface>:$output
36833684
);
36843685

36853686
let extraClassDefinition = [{
@@ -3701,10 +3702,10 @@ def TestDummyTensorOp : TEST_Op<"dummy_tensor_op",
37013702

37023703
def TestDummyMemrefOp : TEST_Op<"dummy_memref_op", []> {
37033704
let arguments = (ins
3704-
Arg<TestMemrefType>:$input
3705+
Arg<Bufferization_BufferLikeTypeInterface>:$input
37053706
);
37063707
let results = (outs
3707-
Arg<TestMemrefType>:$output
3708+
Arg<Bufferization_BufferLikeTypeInterface>:$output
37083709
);
37093710
}
37103711

@@ -3714,7 +3715,7 @@ def TestCreateTensorOp : TEST_Op<"create_tensor_op",
37143715
"bufferizesToMemoryWrite", "getAliasingValues",
37153716
"bufferizesToAllocation"]>]> {
37163717
let arguments = (ins);
3717-
let results = (outs Arg<TestTensorType>:$output);
3718+
let results = (outs Arg<Bufferization_TensorLikeTypeInterface>:$output);
37183719
let extraClassDefinition = [{
37193720
bool test::TestCreateTensorOp::bufferizesToMemoryRead(::mlir::OpOperand&,
37203721
const ::mlir::bufferization::AnalysisState&) {
@@ -3738,7 +3739,7 @@ def TestCreateTensorOp : TEST_Op<"create_tensor_op",
37383739

37393740
def TestCreateMemrefOp : TEST_Op<"create_memref_op"> {
37403741
let arguments = (ins);
3741-
let results = (outs Arg<TestMemrefType>:$output);
3742+
let results = (outs Arg<Bufferization_BufferLikeTypeInterface>:$output);
37423743
}
37433744

37443745
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)