Skip to content

Commit c780184

Browse files
committed
[MLIR][Transform] Expose map layout option in OneShotBufferizeOp
Expose `function-boundary-type-conversion` in `OneShotBufferizeOp`. To reuse options between passes and transform operations, create a `BufferizationEnums.td`. Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D137833
1 parent 0a37290 commit c780184

File tree

13 files changed

+88
-24
lines changed

13 files changed

+88
-24
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
#include "mlir/Support/LLVM.h"
1515
#include "llvm/ADT/SetVector.h"
1616

17+
#include "mlir/Dialect/Bufferization/IR/BufferizationEnums.h.inc"
18+
1719
namespace mlir {
1820
class OpBuilder;
1921

@@ -187,12 +189,6 @@ struct BufferizationOptions {
187189
using UnknownTypeConverterFn = std::function<BaseMemRefType(
188190
Value, unsigned, const BufferizationOptions &)>;
189191

190-
enum class LayoutMapOption : int8_t {
191-
InferLayoutMap = 0,
192-
IdentityLayoutMap = 1,
193-
FullyDynamicLayoutMap = 2
194-
};
195-
196192
BufferizationOptions();
197193

198194
/// Try to cast the given op to BufferizableOpInterface if the op is allow
@@ -585,6 +581,10 @@ bool defaultIsRepetitiveRegion(BufferizableOpInterface bufferizableOp,
585581
} // namespace bufferization
586582
} // namespace mlir
587583

584+
//===----------------------------------------------------------------------===//
585+
// Bufferization Interfaces
586+
//===----------------------------------------------------------------------===//
587+
588588
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h.inc"
589589

590590
#endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZABLEOPINTERFACE_H_
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
//===- BufferizationEnums.td - Bufferization enums ---------*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This is the definition file for enums used in Bufferization.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef BUFFERIZATION_ENUMS
14+
#define BUFFERIZATION_ENUMS
15+
16+
include "mlir/IR/EnumAttr.td"
17+
18+
def LayoutMapOption : I32EnumAttr<"LayoutMapOption",
19+
"option for map layout", [
20+
I32EnumAttrCase<"InferLayoutMap", 0>,
21+
I32EnumAttrCase<"IdentityLayoutMap", 1>,
22+
I32EnumAttrCase<"FullyDynamicLayoutMap", 2>
23+
]> {
24+
let cppNamespace = "::mlir::bufferization";
25+
}
26+
27+
#endif // BUFFERIZATION_ENUMS

mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,9 @@ add_mlir_dialect(BufferizationOps bufferization)
22
add_mlir_doc(BufferizationOps BufferizationOps Dialects/ -gen-dialect-doc)
33
add_mlir_interface(AllocationOpInterface)
44
add_mlir_interface(BufferizableOpInterface)
5+
6+
set(LLVM_TARGET_DEFINITIONS BufferizationEnums.td)
7+
mlir_tablegen(BufferizationEnums.h.inc -gen-enum-decls)
8+
mlir_tablegen(BufferizationEnums.cpp.inc -gen-enum-defs)
9+
add_public_tablegen_target(MLIRBufferizationEnumsIncGen)
10+
add_dependencies(mlir-headers MLIRBufferizationEnumsIncGen)

mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMOPS_BUFFERIZATIONTRANSFORMOPS_H
1010
#define MLIR_DIALECT_BUFFERIZATION_TRANSFORMOPS_BUFFERIZATIONTRANSFORMOPS_H
1111

12+
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
1213
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
1314
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
1415
#include "mlir/IR/OpImplementation.h"

mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#ifndef BUFFERIZATION_TRANSFORM_OPS
1010
#define BUFFERIZATION_TRANSFORM_OPS
1111

12+
include "mlir/Dialect/Bufferization/IR/BufferizationEnums.td"
1213
include "mlir/Dialect/Transform/IR/TransformDialect.td"
1314
include "mlir/Dialect/Transform/IR/TransformEffects.td"
1415
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
@@ -42,6 +43,7 @@ def OneShotBufferizeOp
4243

4344
let arguments = (
4445
ins PDL_Operation:$target,
46+
OptionalAttr<LayoutMapOption>:$function_boundary_type_conversion,
4547
DefaultValuedAttr<BoolAttr, "false">:$allow_return_allocs,
4648
DefaultValuedAttr<BoolAttr, "false">:$allow_unknown_ops,
4749
DefaultValuedAttr<BoolAttr, "false">:$bufferize_function_boundaries,
@@ -52,7 +54,10 @@ def OneShotBufferizeOp
5254

5355
let results = (outs);
5456

55-
let assemblyFormat = "$target attr-dict";
57+
let assemblyFormat = [{
58+
(`layout` `{` $function_boundary_type_conversion^ `}`)?
59+
$target attr-dict
60+
}];
5661
}
5762

5863
#endif // BUFFERIZATION_TRANSFORM_OPS

mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRBufferizationDialect
1010
DEPENDS
1111
MLIRAllocationOpInterfaceIncGen
1212
MLIRBufferizationOpsIncGen
13+
MLIRBufferizationEnumsIncGen
1314

1415
LINK_LIBS PUBLIC
1516
MLIRAffineDialect

mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ transform::OneShotBufferizeOp::apply(TransformResults &transformResults,
3434
options.createDeallocs = getCreateDeallocs();
3535
options.testAnalysisOnly = getTestAnalysisOnly();
3636
options.printConflicts = getPrintConflicts();
37+
if (getFunctionBoundaryTypeConversion().has_value())
38+
options.functionBoundaryTypeConversion =
39+
*getFunctionBoundaryTypeConversion();
3740

3841
ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
3942
for (Operation *target : payloadOps) {
@@ -94,6 +97,8 @@ class BufferizationTransformDialectExtension
9497
#define GET_OP_CLASSES
9598
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc"
9699

100+
#include "mlir/Dialect/Bufferization/IR/BufferizationEnums.cpp.inc"
101+
97102
void mlir::bufferization::registerTransformDialectExtension(
98103
DialectRegistry &registry) {
99104
registry.addExtensions<BufferizationTransformDialectExtension>();

mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -163,14 +163,13 @@ struct FinalizingBufferizePass
163163
}
164164
};
165165

166-
static BufferizationOptions::LayoutMapOption
167-
parseLayoutMapOption(const std::string &s) {
166+
static LayoutMapOption parseLayoutMapOption(const std::string &s) {
168167
if (s == "fully-dynamic-layout-map")
169-
return BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap;
168+
return LayoutMapOption::FullyDynamicLayoutMap;
170169
if (s == "identity-layout-map")
171-
return BufferizationOptions::LayoutMapOption::IdentityLayoutMap;
170+
return LayoutMapOption::IdentityLayoutMap;
172171
if (s == "infer-layout-map")
173-
return BufferizationOptions::LayoutMapOption::InferLayoutMap;
172+
return LayoutMapOption::InferLayoutMap;
174173
llvm_unreachable("invalid layout map option");
175174
}
176175

@@ -216,19 +215,17 @@ struct OneShotBufferizePass
216215
opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
217216

218217
// Configure type converter.
219-
BufferizationOptions::LayoutMapOption unknownTypeConversionOption =
218+
LayoutMapOption unknownTypeConversionOption =
220219
parseLayoutMapOption(unknownTypeConversion);
221220
opt.unknownTypeConverterFn = [=](Value value, unsigned memorySpace,
222221
const BufferizationOptions &options) {
223222
auto tensorType = value.getType().cast<TensorType>();
224-
if (unknownTypeConversionOption ==
225-
BufferizationOptions::LayoutMapOption::IdentityLayoutMap)
223+
if (unknownTypeConversionOption == LayoutMapOption::IdentityLayoutMap)
226224
return bufferization::getMemRefTypeWithStaticIdentityLayout(
227225
tensorType, memorySpace);
228-
assert(
229-
unknownTypeConversionOption ==
230-
BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap &&
231-
"invalid layout map option");
226+
assert(unknownTypeConversionOption ==
227+
LayoutMapOption::FullyDynamicLayoutMap &&
228+
"invalid layout map option");
232229
return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
233230
memorySpace);
234231
};

mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ add_mlir_dialect_library(MLIRBufferizationTransforms
1818

1919
DEPENDS
2020
MLIRBufferizationPassIncGen
21+
MLIRBufferizationEnumsIncGen
2122

2223
LINK_LIBS PUBLIC
2324
MLIRBufferizationDialect

mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
6969

7070
BaseMemRefType memrefType;
7171
if (options.functionBoundaryTypeConversion ==
72-
BufferizationOptions::LayoutMapOption::IdentityLayoutMap) {
72+
LayoutMapOption::IdentityLayoutMap) {
7373
memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
7474
} else {
7575
// Note: Layout maps on function parameters cannot be inferred. The best we
@@ -471,7 +471,7 @@ struct FuncOpInterface
471471

472472
BaseMemRefType resultType;
473473
if (options.functionBoundaryTypeConversion ==
474-
BufferizationOptions::LayoutMapOption::IdentityLayoutMap) {
474+
LayoutMapOption::IdentityLayoutMap) {
475475
resultType = getMemRefTypeWithStaticIdentityLayout(tensorType);
476476
} else {
477477
// Note: If `InferLayoutMap`, cast are later folded away.

0 commit comments

Comments
 (0)