//===- OneShotBufferization.cpp - One-shot bufferization unit tests -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/TensorEncoding.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/PassManager.h"

#include "gtest/gtest.h"

using namespace mlir;

namespace {

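// Test-only tensor encoding attribute: a thin wrapper around StringAttr used
// as the encoding of the tensors in the IR below.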
struct TestTensorAttr : public StringAttr {
  using mlir::StringAttr::StringAttr;

  static bool classof(mlir::Attribute attr) {
    return mlir::isa<mlir::StringAttr>(attr);
  }

  static TestTensorAttr fromStringAttr(StringAttr attr) {
    return mlir::dyn_cast<TestTensorAttr>(attr);
  }
};

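// External model attaching the VerifiableTensorEncoding interface to
// TestTensorAttr so that tensors carrying the test encoding verify.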
class TestTensorEncodingVerifier final
    : public mlir::VerifiableTensorEncoding::ExternalModel<
          TestTensorEncodingVerifier, TestTensorAttr> {
public:
  using ConcreteEntity = mlir::StringAttr;

  mlir::LogicalResult verifyEncoding(
      mlir::Attribute attr, mlir::ArrayRef<int64_t> shape, mlir::Type,
      mlir::function_ref<mlir::InFlightDiagnostic()> emitError) const {
    std::ignore = shape;

    if (mlir::isa<TestTensorAttr>(attr)) {
      return mlir::success();
    }
    return emitError() << "unknown tensor encoding: " << attr;
  }
};

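// Test-only memref layout attribute, also a StringAttr wrapper. Bufferization
// is expected to map a TestTensorAttr encoding onto this layout.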
struct TestMemRefAttr : public mlir::StringAttr {
  using mlir::StringAttr::StringAttr;

  static bool classof(mlir::Attribute attr) {
    return mlir::isa<mlir::StringAttr>(attr);
  }

  mlir::AffineMap getAffineMap() const {
    return mlir::AffineMap::getMultiDimIdentityMap(1, getContext());
  }
};

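// External model attaching MemRefLayoutAttrInterface to TestMemRefAttr,
// treating the layout as the identity map.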
class TestMemRefAttrLayout final
    : public mlir::MemRefLayoutAttrInterface::ExternalModel<
          TestMemRefAttrLayout, TestMemRefAttr> {
public:
  using ConcreteEntity = mlir::StringAttr;

  bool isIdentity(mlir::Attribute) const { return true; }
  mlir::AffineMap getAffineMap(mlir::Attribute attr) const {
    return cast<TestMemRefAttr>(attr).getAffineMap();
  }
  mlir::LogicalResult
  verifyLayout(mlir::Attribute attr, mlir::ArrayRef<int64_t> shape,
               mlir::function_ref<mlir::InFlightDiagnostic()> emitError) const {
    std::ignore = shape;

    if (mlir::isa<TestMemRefAttr>(attr)) {
      return mlir::success();
    }
    return emitError() << "unknown memref layout: " << attr;
  }
};

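// Checks that one-shot module bufferization translates a custom tensor
// encoding into a custom memref layout via constructMemRefLayoutFn and
// functionArgTypeConverterFn.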
TEST(OneShotBufferizationTest, BufferizeTensorEncodingIntoMemRefLayout) {
  MLIRContext context;
  context.getOrLoadDialect<BuiltinDialect>();
  context.getOrLoadDialect<func::FuncDialect>();
  context.getOrLoadDialect<bufferization::BufferizationDialect>();

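  // Attach the external models to the test attributes and register the
  // bufferization models of the func dialect.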
  DialectRegistry registry;
  registry.addExtension(+[](mlir::MLIRContext *ctx, BuiltinDialect *) {
    TestTensorAttr::attachInterface<TestTensorEncodingVerifier>(*ctx);
    TestMemRefAttr::attachInterface<TestMemRefAttrLayout>(*ctx);
  });
  bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
      registry);
  context.appendDialectRegistry(registry);

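  // Test IR: every tensor type carries a string encoding that is expected to
  // survive bufferization as the memref layout.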
  const char *const code = R"mlir(
    func.func @foo(%t: tensor<42xf32, "hello">)
        -> tensor<42xf32, "hello"> {
      return %t : tensor<42xf32, "hello">
    }

    func.func @bar(%t1: tensor<42xf32, "hello">)
        -> (tensor<42xf32, "hello">, tensor<12xf32, "not hello">) {
      %out1 = func.call @foo(%t1) : (tensor<42xf32, "hello">)
          -> tensor<42xf32, "hello">

      %out2 = bufferization.alloc_tensor() : tensor<12xf32, "not hello">

      return %out1, %out2 : tensor<42xf32, "hello">, tensor<12xf32, "not hello">
    }
  )mlir";

  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(code, &context);
  ASSERT_NE(module.get(), nullptr) << "parsing should be successful";

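  // Bufferize across function boundaries and turn each tensor's TestTensorAttr
  // encoding into a TestMemRefAttr layout on the resulting memref.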
  bufferization::OneShotBufferizationOptions options{};
  options.bufferizeFunctionBoundaries = true;
  options.constructMemRefLayoutFn =
      [](TensorType tensor) -> MemRefLayoutAttrInterface {
    assert(isa<RankedTensorType>(tensor) && "tests only builtin tensors");
    auto tensorType = cast<RankedTensorType>(tensor);
    if (auto encoding = dyn_cast<TestTensorAttr>(tensorType.getEncoding())) {
      return cast<MemRefLayoutAttrInterface>(
          TestMemRefAttr::get(tensor.getContext(), encoding.strref()));
    }
    return {};
  };
  options.functionArgTypeConverterFn =
      [&](bufferization::TensorLikeType tensor, Attribute memSpace,
          func::FuncOp, const bufferization::BufferizationOptions &) {
        assert(isa<RankedTensorType>(tensor) && "tests only builtin tensors");
        auto tensorType = cast<RankedTensorType>(tensor);
        auto layout = options.constructMemRefLayoutFn(tensorType);
        return cast<bufferization::BufferLikeType>(
            MemRefType::get(tensorType.getShape(), tensorType.getElementType(),
                            layout, memSpace));
      };

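  // Run one-shot bufferization on the whole module.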
  bufferization::BufferizationState state;
  ASSERT_TRUE(succeeded(bufferization::runOneShotModuleBufferize(
      module->getOperation(), options, state)));

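  // Returns true if `type` is a memref whose layout is a TestMemRefAttr with
  // the expected string value.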
  const auto checkType = [](Type type, StringRef expectedLayoutValue) {
    if (auto memref = dyn_cast<MemRefType>(type)) {
      if (auto layout = memref.getLayout();
          isa_and_nonnull<TestMemRefAttr>(layout)) {
        return cast<TestMemRefAttr>(layout) == expectedLayoutValue;
      }
    }
    return false;
  };

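  // After bufferization, all function arguments and results should be memrefs
  // whose layouts carry the original encoding strings.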
  auto fooOp = *module->getOps<func::FuncOp>().begin();
  ASSERT_TRUE(checkType(fooOp.getArgumentTypes()[0], "hello"));
  ASSERT_TRUE(checkType(fooOp.getResultTypes()[0], "hello"));

  auto barOp = *std::next(module->getOps<func::FuncOp>().begin());
  ASSERT_TRUE(checkType(barOp.getArgumentTypes()[0], "hello"));
  ASSERT_TRUE(checkType(barOp.getResultTypes()[0], "hello"));
  ASSERT_TRUE(checkType(barOp.getResultTypes()[1], "not hello"));
}

} // end anonymous namespace