Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
//===- BufferizationTypeInterfaces.h - Type Interfaces ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONTYPEINTERFACES_H_
#define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONTYPEINTERFACES_H_

//===----------------------------------------------------------------------===//
// Bufferization Type Interfaces
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h.inc"

#endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONTYPEINTERFACES_H_
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
//===- BufferizationTypeInterfaces.td - Type Interfaces ----*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is the definition file for type interfaces used in Bufferization.
//
//===----------------------------------------------------------------------===//

#ifndef BUFFERIZATION_TYPE_INTERFACES
#define BUFFERIZATION_TYPE_INTERFACES

include "mlir/IR/OpBase.td"

def Bufferization_TensorLikeTypeInterface
: TypeInterface<"TensorLikeType"> {
let cppNamespace = "::mlir::bufferization";
let description = [{
Indicates that this type is a tensor type (similarly to a MLIR builtin
tensor) for bufferization purposes.

The interface currently has no methods as it is used by types to opt into
being supported by the bufferization procedures.
}];
}

def Bufferization_MemRefLikeTypeInterface
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can ask for one more change? The term used in the bufferization is "buffer", not memref. Can we change this to BufferLikeTypeInterface? E.g., we have getBufferType in BufferizableOpInterface.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have thought of this as well but wasn't sure it's better than memref. If you prefer "buffer", sure.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done!

: TypeInterface<"MemRefLikeType"> {
let cppNamespace = "::mlir::bufferization";
let description = [{
Indicates that this type is a memref type (similarly to a MLIR builtin
memref) for bufferization purposes.

The interface currently has no methods as it is used by types to opt into
being supported by the bufferization procedures.
}];
}

#endif // BUFFERIZATION_TYPE_INTERFACES
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,9 @@ mlir_tablegen(BufferizationEnums.h.inc -gen-enum-decls)
mlir_tablegen(BufferizationEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRBufferizationEnumsIncGen)
add_dependencies(mlir-headers MLIRBufferizationEnumsIncGen)

set(LLVM_TARGET_DEFINITIONS BufferizationTypeInterfaces.td)
mlir_tablegen(BufferizationTypeInterfaces.h.inc -gen-type-interface-decls)
mlir_tablegen(BufferizationTypeInterfaces.cpp.inc -gen-type-interface-defs)
add_public_tablegen_target(MLIRBufferizationTypeInterfacesIncGen)
add_dependencies(mlir-headers MLIRBufferizationTypeInterfacesIncGen)
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,10 @@ def OneShotBufferizePass : Pass<"one-shot-bufferize", "ModuleOp"> {
Statistic<"numTensorOutOfPlace", "num-tensor-out-of-place",
"Number of out-of-place tensor OpOperands">,
];

let dependentDialects = [
"bufferization::BufferizationDialect", "memref::MemRefDialect"
];
}

def PromoteBuffersToStackPass
Expand Down
26 changes: 26 additions & 0 deletions mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Transforms/InliningUtils.h"

Expand Down Expand Up @@ -51,6 +53,16 @@ struct BufferizationInlinerInterface : public DialectInlinerInterface {
return true;
}
};

template <typename Tensor>
struct BuiltinTensorExternalModel
: TensorLikeType::ExternalModel<BuiltinTensorExternalModel<Tensor>,
Tensor> {};

template <typename MemRef>
struct BuiltinMemRefExternalModel
: MemRefLikeType::ExternalModel<BuiltinMemRefExternalModel<MemRef>,
MemRef> {};
} // namespace

//===----------------------------------------------------------------------===//
Expand All @@ -63,6 +75,20 @@ void mlir::bufferization::BufferizationDialect::initialize() {
#include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"
>();
addInterfaces<BufferizationInlinerInterface>();

// Note: Unlike with other external models, declaring bufferization's
// "promised interfaces" in builtins for TensorLike and MemRefLike type
// interfaces is not possible (due to builtins being independent of
// bufferization). Thus, the compromise is to attach these interfaces directly
// during dialect initialization.
RankedTensorType::attachInterface<
BuiltinTensorExternalModel<RankedTensorType>>(*getContext());
UnrankedTensorType::attachInterface<
BuiltinTensorExternalModel<UnrankedTensorType>>(*getContext());
MemRefType::attachInterface<BuiltinMemRefExternalModel<MemRefType>>(
*getContext());
UnrankedMemRefType::attachInterface<
BuiltinMemRefExternalModel<UnrankedMemRefType>>(*getContext());
}

LogicalResult BufferizationDialect::verifyRegionArgAttribute(
Expand Down
5 changes: 0 additions & 5 deletions mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,6 @@ struct OneShotBufferizePass
OneShotBufferizePass> {
using Base::Base;

void getDependentDialects(DialectRegistry &registry) const override {
registry
.insert<bufferization::BufferizationDialect, memref::MemRefDialect>();
}

void runOnOperation() override {
OneShotBufferizationOptions opt;
if (!options) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// RUN: mlir-opt %s -test-tensorlike-memreflike -split-input-file | FileCheck %s

// CHECK: func.func @builtin_unranked
// CHECK-SAME: {found = {operand_0 = "is_tensor_like", result_0 = "is_memref_like"}}
func.func @builtin_unranked(%t: tensor<*xf32>) -> (memref<*xf32>)
{
%0 = bufferization.to_memref %t : tensor<*xf32> to memref<*xf32>
return %0 : memref<*xf32>
}

// -----

// CHECK: func.func @builtin_ranked
// CHECK-SAME: {found = {operand_0 = "is_tensor_like", result_0 = "is_memref_like"}}
func.func @builtin_ranked(%t: tensor<42xf32>) -> (memref<42xf32>)
{
%0 = bufferization.to_memref %t : tensor<42xf32> to memref<42xf32>
return %0 : memref<42xf32>
}

// -----

// CHECK: func.func @custom_tensor
// CHECK-SAME: {found = {operand_0 = "is_tensor_like"}}
func.func @custom_tensor(%t: !test.test_tensor<[42], f32>) -> ()
{
return
}

// -----

// CHECK: func.func @custom_memref
// CHECK-SAME: {found = {operand_0 = "is_memref_like"}}
func.func @custom_memref(%t: !test.test_memref<[42], f32>) -> ()
{
return
}
8 changes: 8 additions & 0 deletions mlir/test/lib/Dialect/Bufferization/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Exclude tests from libMLIR.so
add_mlir_library(MLIRBufferizationTestPasses
TestTensorCopyInsertion.cpp
TestTensorLikeAndMemRefLike.cpp

EXCLUDE_FROM_LIBMLIR
)
Expand All @@ -9,4 +10,11 @@ mlir_target_link_libraries(MLIRBufferizationTestPasses PUBLIC
MLIRBufferizationTransforms
MLIRIR
MLIRPass
MLIRTestDialect
)

target_include_directories(MLIRBufferizationTestPasses
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/../../Dialect/Test
${CMAKE_CURRENT_BINARY_DIR}/../../Dialect/Test
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
//===- TestTensorLikeAndMemRefLike.cpp - Bufferization Test -----*- c++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "TestDialect.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Pass/Pass.h"

#include <string>

using namespace mlir;

namespace {
std::string getImplementationStatus(Type type) {
if (isa<bufferization::TensorLikeType>(type)) {
return "is_tensor_like";
}
if (isa<bufferization::MemRefLikeType>(type)) {
return "is_memref_like";
}
return {};
}

DictionaryAttr findAllImplementeesOfTensorOrMemRefLike(func::FuncOp funcOp) {
llvm::SmallVector<NamedAttribute> attributes;

const auto funcType = funcOp.getFunctionType();
for (auto [index, inputType] : llvm::enumerate(funcType.getInputs())) {
const auto status = getImplementationStatus(inputType);
if (status.empty()) {
continue;
}

attributes.push_back(
NamedAttribute(StringAttr::get(funcOp.getContext(),
"operand_" + std::to_string(index)),
StringAttr::get(funcOp.getContext(), status)));
}

for (auto [index, resultType] : llvm::enumerate(funcType.getResults())) {
const auto status = getImplementationStatus(resultType);
if (status.empty()) {
continue;
}

attributes.push_back(NamedAttribute(
StringAttr::get(funcOp.getContext(), "result_" + std::to_string(index)),
StringAttr::get(funcOp.getContext(), status)));
}

return mlir::DictionaryAttr::get(funcOp.getContext(), attributes);
}

/// This pass tests whether specified types implement TensorLike and (or)
/// MemRefLike type interfaces defined in bufferization.
///
/// The pass analyses operation signature. When the aforementioned interface
/// implementation found, an attribute is added to the operation, signifying the
/// associated operand / result.
struct TestTensorLikeAndMemRefLikePass
: public PassWrapper<TestTensorLikeAndMemRefLikePass,
OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTensorLikeAndMemRefLikePass)

void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<bufferization::BufferizationDialect, test::TestDialect>();
}
StringRef getArgument() const final { return "test-tensorlike-memreflike"; }
StringRef getDescription() const final {
return "Module pass to test custom types that implement TensorLike / "
"MemRefLike interfaces";
}

void runOnOperation() override {
auto op = getOperation();

op.walk([](func::FuncOp funcOp) {
const auto dict = findAllImplementeesOfTensorOrMemRefLike(funcOp);
if (!dict.empty()) {
funcOp->setAttr("found", dict);
}
});
}
};
} // namespace

namespace mlir::test {
void registerTestTensorLikeAndMemRefLikePass() {
PassRegistration<TestTensorLikeAndMemRefLikePass>();
}
} // namespace mlir::test
1 change: 1 addition & 0 deletions mlir/test/lib/Dialect/Test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ mlir_target_link_libraries(MLIRTestDialect PUBLIC
MLIRTransformUtils
MLIRTransforms
MLIRValueBoundsOpInterface
MLIRBufferizationDialect
)

add_mlir_translation_library(MLIRTestFromLLVMIRTranslation
Expand Down
46 changes: 46 additions & 0 deletions mlir/test/lib/Dialect/Test/TestTypeDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ include "TestAttrDefs.td"
include "TestInterfaces.td"
include "mlir/IR/BuiltinTypes.td"
include "mlir/Interfaces/DataLayoutInterfaces.td"
include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td"

// All of the types will extend this class.
class Test_Type<string name, list<Trait> traits = []>
Expand Down Expand Up @@ -403,4 +404,49 @@ def TestTypeOpAsmTypeInterface : Test_Type<"TestTypeOpAsmTypeInterface",
let mnemonic = "op_asm_type_interface";
}

def TestTensorType : Test_Type<"TestTensor",
[Bufferization_TensorLikeTypeInterface, ShapedTypeInterface]> {
let mnemonic = "test_tensor";
let parameters = (ins
ArrayRefParameter<"int64_t">:$shape,
"mlir::Type":$elementType
);
let assemblyFormat = "`<` `[` $shape `]` `,` $elementType `>`";

let extraClassDeclaration = [{
// ShapedTypeInterface:
bool hasRank() const {
return true;
}
test::TestTensorType cloneWith(std::optional<llvm::ArrayRef<int64_t>> shape,
mlir::Type elementType) const {
return test::TestTensorType::get(
getContext(), shape.value_or(getShape()), elementType);
}
}];
}

def TestMemrefType : Test_Type<"TestMemref",
[Bufferization_MemRefLikeTypeInterface, ShapedTypeInterface]> {
let mnemonic = "test_memref";
let parameters = (ins
ArrayRefParameter<"int64_t">:$shape,
"mlir::Type":$elementType,
DefaultValuedParameter<"mlir::Attribute", "nullptr">:$memSpace
);
let assemblyFormat = "`<` `[` $shape `]` `,` $elementType (`,` $memSpace^)? `>`";

let extraClassDeclaration = [{
// ShapedTypeInterface:
bool hasRank() const {
return true;
}
test::TestMemrefType cloneWith(std::optional<llvm::ArrayRef<int64_t>> shape,
mlir::Type elementType) const {
return test::TestMemrefType::get(
getContext(), shape.value_or(getShape()), elementType, getMemSpace());
}
}];
}

#endif // TEST_TYPEDEFS
1 change: 1 addition & 0 deletions mlir/test/lib/Dialect/Test/TestTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <tuple>

#include "TestTraits.h"
#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
Expand Down
2 changes: 2 additions & 0 deletions mlir/tools/mlir-opt/mlir-opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ void registerTestSPIRVCPURunnerPipeline();
void registerTestSPIRVFuncSignatureConversion();
void registerTestSPIRVVectorUnrolling();
void registerTestTensorCopyInsertionPass();
void registerTestTensorLikeAndMemRefLikePass();
void registerTestTensorTransforms();
void registerTestTopologicalSortAnalysisPass();
void registerTestTransformDialectEraseSchedulePass();
Expand Down Expand Up @@ -291,6 +292,7 @@ void registerTestPasses() {
mlir::test::registerTestSPIRVFuncSignatureConversion();
mlir::test::registerTestSPIRVVectorUnrolling();
mlir::test::registerTestTensorCopyInsertionPass();
mlir::test::registerTestTensorLikeAndMemRefLikePass();
mlir::test::registerTestTensorTransforms();
mlir::test::registerTestTopologicalSortAnalysisPass();
mlir::test::registerTestTransformDialectEraseSchedulePass();
Expand Down