Skip to content

Commit be85978

Browse files
[mlir][bufferization] Support custom types (1/N)
Following the introduction of TensorLike and BufferLike type interfaces (see 00eaff3), introduce minimal changes required to bufferize a custom tensor operation into a custom buffer operation. To achieve this, a new conversion dialect interface is added that abstracts away the differences between existing (tensor -> memref) and custom conversions. The scope of the changes is intentionally limited (for example, BufferizableOpInterface is untouched) in order to first understand the basics and reach consensus design-wise.
1 parent 66580f7 commit be85978

File tree

19 files changed

+451
-108
lines changed

19 files changed

+451
-108
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <optional>
1818

1919
#include "mlir/Dialect/Bufferization/IR/BufferizationEnums.h.inc"
20+
#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
2021

2122
namespace mlir {
2223
class OpBuilder;
@@ -615,7 +616,7 @@ FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
615616
/// IR, this function can be used.
616617
///
617618
/// This function is a wrapper around BufferizableOpInterface::getBufferType.
618-
FailureOr<BaseMemRefType> getBufferType(Value value,
619+
FailureOr<BufferLikeType> getBufferType(Value value,
619620
const BufferizationOptions &options,
620621
const BufferizationState &state);
621622

@@ -629,7 +630,7 @@ FailureOr<BaseMemRefType> getBufferType(Value value,
629630
/// IR, this function can be used.
630631
///
631632
/// This function is a wrapper around `BufferizableOpInterface::getBufferType`.
632-
FailureOr<BaseMemRefType> getBufferType(Value value,
633+
FailureOr<BufferLikeType> getBufferType(Value value,
633634
const BufferizationOptions &options,
634635
const BufferizationState &state,
635636
SmallVector<Value> &invocationStack);
@@ -739,6 +740,18 @@ AliasingValueList unknownGetAliasingValues(OpOperand &opOperand);
739740
/// This is the default implementation of
740741
/// BufferizableOpInterface::hasTensorSemantics
741742
bool defaultHasTensorSemantics(Operation *op);
743+
744+
/// This is a helper function used when buffer type is guaranteed to be memref.
745+
FailureOr<BaseMemRefType> castToMemRef(FailureOr<BufferLikeType> bufferType);
746+
747+
/// This function is a free-standing helper that relies on
748+
/// bufferization::ConversionInterface to verify the types in tensor and buffer
749+
/// worlds match.
750+
bool typesMatchAfterBufferization(Operation &op, Value tensor, Value buffer);
751+
752+
/// This function is a free-standing helper that relies on
753+
/// bufferization::ConversionInterface to perform the conversion.
754+
Type getTensorFromBuffer(Type buffer);
742755
} // namespace detail
743756

744757
} // namespace bufferization
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
//===- BufferizationConversionInterface.h - Dialect Interface ---*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONCONVERSIONINTERFACE_H_
10+
#define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONCONVERSIONINTERFACE_H_
11+
12+
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
13+
#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
14+
#include "mlir/IR/DialectInterface.h"
15+
16+
namespace mlir {
17+
namespace bufferization {
18+
19+
/// This class defines a virtual interface for conversions between tensor-like
20+
/// and buffer-like types.
21+
struct ConversionDialectInterface
22+
: DialectInterface::Base<ConversionDialectInterface> {
23+
using Base::Base;
24+
25+
/// Hook to customize tensor-like -> buffer-like conversion within a given
26+
/// dialect. Returns a buffer-like type for the specific tensor-like type.
27+
virtual FailureOr<BufferLikeType> getBufferType(
28+
Value value, const BufferizationOptions &options,
29+
const BufferizationState &state,
30+
function_ref<InFlightDiagnostic(const Twine &)> emitError) const = 0;
31+
32+
/// Hook to customize type checking between tensor-like and buffer-like types.
33+
/// Given tensor `T` and buffer `B = getBufferType(T, ...)`, the call to
34+
/// `typesMatch(T, B)` must return true.
35+
virtual LogicalResult typesMatch(
36+
TensorLikeType tensor, BufferLikeType buffer,
37+
function_ref<InFlightDiagnostic(const Twine &)> emitError) const = 0;
38+
39+
/// Hook to customize buffer-like -> tensor-like conversion, which is the
40+
/// opposite of bufferization.
41+
virtual TensorLikeType getTensorFromBuffer(BufferLikeType buffer) const = 0;
42+
};
43+
44+
/// Interface collection for conversion between tensor-like and buffer-like
45+
/// types, dispatches to a concrete interface implementation based on the
46+
/// dialect to which the given type belongs.
47+
struct ConversionInterface
48+
: DialectInterfaceCollection<ConversionDialectInterface> {
49+
using Base::Base;
50+
51+
/// Dispatches to ConversionDialectInterface::getBufferType() of the dialect
52+
/// associated with the value type.
53+
FailureOr<BufferLikeType> getBufferType(
54+
Value value, const BufferizationOptions &options,
55+
const BufferizationState &state,
56+
function_ref<InFlightDiagnostic(const Twine &)> emitError) const;
57+
58+
/// Dispatches to ConversionDialectInterface::typesMatch() of the dialect
59+
/// associated with the value type.
60+
LogicalResult
61+
typesMatch(TensorLikeType tensor, BufferLikeType buffer,
62+
function_ref<InFlightDiagnostic(const Twine &)> emitError) const;
63+
64+
/// Dispatches to ConversionDialectInterface::getTensorFromBuffer() of the
65+
/// dialect associated with the value type.
66+
TensorLikeType getTensorFromBuffer(BufferLikeType buffer) const;
67+
};
68+
69+
} // namespace bufferization
70+
} // namespace mlir
71+
72+
#endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONTYPEINTERFACES_H_

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td"
1313
include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td"
1414
include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
15+
include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td"
1516
include "mlir/Dialect/Bufferization/IR/BufferizationBase.td"
1617
include "mlir/Interfaces/DestinationStyleOpInterface.td"
1718
include "mlir/Interfaces/InferTypeOpInterface.td"
@@ -386,20 +387,28 @@ def Bufferization_DeallocTensorOp : Bufferization_Op<"dealloc_tensor",
386387
// ToTensorOp
387388
//===----------------------------------------------------------------------===//
388389

390+
class Bufferization_TensorAndBufferMatch<string tensor, string buffer> : PredOpTrait<
391+
"specified tensor and buffer types match",
392+
CPred<
393+
"::mlir::bufferization::detail::typesMatchAfterBufferization("
394+
"$_op, $" # tensor # ", $" # buffer #")"
395+
>
396+
>;
397+
389398
def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
390399
BufferizableOpInterface,
391400
SameOperandsAndResultShape,
392401
SameOperandsAndResultElementType,
393-
AllElementTypesMatch<["memref", "result"]>
402+
Bufferization_TensorAndBufferMatch<"result", "buffer">
394403
]> {
395-
let summary = "create a tensor from a `memref`";
404+
let summary = "create a buffer-like type from a tensor-like type";
396405
let description = [{
397-
An operation that creates a tensor from a `memref`. The result value is a
398-
tensor whose shape and element type match the memref operand.
406+
An operation that creates a tensor from a buffer. The result value is a
407+
tensor-like type whose shape and element type match the buffer-like operand.
399408

400409
The opposite of this op is `to_buffer`. Together, these two ops are
401410
useful for source/target materializations when doing type conversions
402-
involving tensors and memrefs.
411+
involving tensors and buffers.
403412

404413
Example:
405414

@@ -441,11 +450,11 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
441450
away. However, such IR is no longer bufferizable with One-Shot Bufferize.
442451
}];
443452

444-
let arguments = (ins Arg<AnyRankedOrUnrankedMemRef,
453+
let arguments = (ins Arg<Bufferization_BufferLikeTypeInterface,
445454
"the reference to load from",
446-
[MemReadAt<0, FullEffect>]>:$memref,
455+
[MemReadAt<0, FullEffect>]>:$buffer,
447456
UnitAttr:$restrict, UnitAttr:$writable);
448-
let results = (outs AnyTensor:$result);
457+
let results = (outs Bufferization_TensorLikeTypeInterface:$result);
449458

450459
let extraClassDeclaration = [{
451460
/// The result of a to_tensor is always a tensor.
@@ -472,19 +481,19 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
472481
FailureOr<BaseMemRefType> getBufferType(
473482
Value value, const BufferizationOptions &options,
474483
const BufferizationState &state, SmallVector<Value> &invocationStack) {
475-
return ::llvm::cast<BaseMemRefType>(getMemref().getType());
484+
return ::llvm::cast<BaseMemRefType>(getBuffer().getType());
476485
}
477486
}];
478487

479488
let assemblyFormat = [{
480-
$memref (`restrict` $restrict^)? (`writable` $writable^)? attr-dict
481-
`:` type($memref) `to` type($result)
489+
$buffer (`restrict` $restrict^)? (`writable` $writable^)? attr-dict
490+
`:` type($buffer) `to` type($result)
482491
}];
483492

484493
let builders = [
485-
OpBuilder<(ins "Value":$memref, CArg<"bool", "false">:$restrict, CArg<"bool", "false">:$writeable), [{
486-
auto rtt = memref::getTensorTypeFromMemRefType(memref.getType());
487-
build($_builder, $_state, rtt, memref, restrict, writeable);
494+
OpBuilder<(ins "Value":$buffer, CArg<"bool", "false">:$restrict, CArg<"bool", "false">:$writeable), [{
495+
auto rtt = bufferization::detail::getTensorFromBuffer(buffer.getType());
496+
build($_builder, $_state, rtt, buffer, restrict, writeable);
488497
}]>
489498
];
490499

@@ -502,10 +511,9 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
502511
SameOperandsAndResultShape,
503512
SameOperandsAndResultElementType,
504513
Pure,
505-
AllShapesMatch<["memref", "tensor"]>,
506-
AllElementTypesMatch<["memref", "tensor"]>
514+
Bufferization_TensorAndBufferMatch<"tensor", "buffer">
507515
]> {
508-
let summary = "cast a tensor to memref";
516+
let summary = "cast a tensor-like type to buffer-like type";
509517
let description = [{
510518
An operation that returns the future buffer of a `tensor`.
511519

@@ -523,8 +531,8 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
523531
the returned buffer) will not be written to.
524532
}];
525533

526-
let arguments = (ins AnyTensor:$tensor, UnitAttr:$read_only);
527-
let results = (outs AnyRankedOrUnrankedMemRef:$memref);
534+
let arguments = (ins Bufferization_TensorLikeTypeInterface:$tensor, UnitAttr:$read_only);
535+
let results = (outs Bufferization_BufferLikeTypeInterface:$buffer);
528536

529537
let extraClassDeclaration = [{
530538
//===------------------------------------------------------------------===//
@@ -559,7 +567,7 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
559567
}];
560568

561569
let assemblyFormat = [{
562-
$tensor (`read_only` $read_only^)? attr-dict `:` type($tensor) `to` type($memref)
570+
$tensor (`read_only` $read_only^)? attr-dict `:` type($tensor) `to` type($buffer)
563571
}];
564572

565573
let hasFolder = 1;

mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,13 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
6565
// The operand was already bufferized. Take its type directly.
6666
callerType = memrefType;
6767
} else {
68-
FailureOr<BaseMemRefType> maybeCallerType =
68+
FailureOr<BufferLikeType> maybeCallerType =
6969
bufferization::getBufferType(opOperand->get(), options, state,
7070
invocationStack);
7171
if (failed(maybeCallerType))
7272
return failure();
73-
callerType = *maybeCallerType;
73+
assert(isa<BaseMemRefType>(*maybeCallerType) && "expected memref type");
74+
callerType = cast<BaseMemRefType>(*maybeCallerType);
7475
}
7576

7677
if (!bufferType) {

mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,8 @@ struct SelectOpInterface
164164
// buffers have different types, they differ only in their layout map. Cast
165165
// both of them to the most dynamic MemRef type.
166166
if (trueBuffer.getType() != falseBuffer.getType()) {
167-
auto targetType =
168-
bufferization::getBufferType(selectOp.getResult(), options, state);
167+
auto targetType = bufferization::detail::castToMemRef(
168+
bufferization::getBufferType(selectOp.getResult(), options, state));
169169
if (failed(targetType))
170170
return failure();
171171
if (trueBuffer.getType() != *targetType)
@@ -187,10 +187,12 @@ struct SelectOpInterface
187187
SmallVector<Value> &invocationStack) const {
188188
auto selectOp = cast<arith::SelectOp>(op);
189189
assert(value == selectOp.getResult() && "invalid value");
190-
auto trueType = bufferization::getBufferType(
191-
selectOp.getTrueValue(), options, state, invocationStack);
192-
auto falseType = bufferization::getBufferType(
193-
selectOp.getFalseValue(), options, state, invocationStack);
190+
auto trueType =
191+
bufferization::detail::castToMemRef(bufferization::getBufferType(
192+
selectOp.getTrueValue(), options, state, invocationStack));
193+
auto falseType =
194+
bufferization::detail::castToMemRef(bufferization::getBufferType(
195+
selectOp.getFalseValue(), options, state, invocationStack));
194196
if (failed(trueType) || failed(falseType))
195197
return failure();
196198
if (*trueType == *falseType)

0 commit comments

Comments
 (0)