Skip to content

Commit 4940643

Browse files
andrey-golubevermilindwalekar
authored andcommitted
[mlir][bufferization] Return BufferLikeType in BufferizableOpInterface (#144867)
Support custom types (2/N): allow value-owning operations (e.g. allocation ops) to bufferize custom tensors into custom buffers. This requires BufferizableOpInterface::getBufferType() to return BufferLikeType instead of BaseMemRefType. Affected implementors of the interface are updated accordingly. Relates to ee070d0.
1 parent ac08a15 commit 4940643

File tree

8 files changed

+112
-44
lines changed

8 files changed

+112
-44
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -480,7 +480,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
480480
FailureOr<BufferLikeType> getBufferType(
481481
Value value, const BufferizationOptions &options,
482482
SmallVector<Value> &invocationStack) {
483-
return ::llvm::cast<BaseMemRefType>(getBuffer().getType());
483+
return getBuffer().getType();
484484
}
485485
}];
486486

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818

1919
namespace mlir::bufferization {
2020
struct BufferizationOptions;
21-
class BufferizationState;
2221
class BufferLikeType;
2322
} // namespace mlir::bufferization
2423

mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -953,8 +953,10 @@ FailureOr<BufferLikeType> bufferization::detail::defaultGetBufferType(
953953
auto tensorType = cast<TensorType>(value.getType());
954954

955955
// No further analysis is possible for a block argument.
956-
if (llvm::isa<BlockArgument>(value))
957-
return bufferization::getMemRefType(tensorType, options);
956+
if (llvm::isa<BlockArgument>(value)) {
957+
return cast<BufferLikeType>(
958+
bufferization::getMemRefType(tensorType, options));
959+
}
958960

959961
// Value is an OpResult.
960962
Operation *op = getOwnerOfValue(value);
@@ -966,8 +968,7 @@ FailureOr<BufferLikeType> bufferization::detail::defaultGetBufferType(
966968
// If the OpResult has an equivalent OpOperand, both OpResult and
967969
// OpOperand bufferize to the exact same buffer type.
968970
Value equivalentOperand = aliases.getAliases().front().opOperand->get();
969-
return asMemRefType(
970-
getBufferType(equivalentOperand, options, invocationStack));
971+
return getBufferType(equivalentOperand, options, invocationStack);
971972
}
972973

973974
// If we do not know the memory space and there is no default memory space,
@@ -977,7 +978,8 @@ FailureOr<BufferLikeType> bufferization::detail::defaultGetBufferType(
977978
if (!memSpace.has_value())
978979
return op->emitError("could not infer memory space");
979980

980-
return getMemRefType(tensorType, options, /*layout=*/{}, *memSpace);
981+
return cast<BufferLikeType>(
982+
getMemRefType(tensorType, options, /*layout=*/{}, *memSpace));
981983
}
982984

983985
bool bufferization::detail::defaultIsRepetitiveRegion(

mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -524,8 +524,8 @@ static FailureOr<BufferLikeType> computeLoopRegionIterArgBufferType(
524524
const BufferizationOptions &options, const BufferizationState &state,
525525
SmallVector<Value> &invocationStack) {
526526
// Determine the buffer type of the init_arg.
527-
auto initArgBufferType = bufferization::detail::asMemRefType(
528-
bufferization::getBufferType(initArg, options, invocationStack));
527+
auto initArgBufferType =
528+
bufferization::getBufferType(initArg, options, invocationStack);
529529
if (failed(initArgBufferType))
530530
return failure();
531531

@@ -551,8 +551,8 @@ static FailureOr<BufferLikeType> computeLoopRegionIterArgBufferType(
551551
} else {
552552
// Note: This typically triggers a recursive call for the buffer type of
553553
// the iter_arg.
554-
auto maybeBufferType = bufferization::detail::asMemRefType(
555-
bufferization::getBufferType(yieldedValue, options, invocationStack));
554+
auto maybeBufferType =
555+
bufferization::getBufferType(yieldedValue, options, invocationStack);
556556
if (failed(maybeBufferType))
557557
return failure();
558558
yieldedValueBufferType = *maybeBufferType;
@@ -715,12 +715,7 @@ struct ForOpInterface
715715
if (auto opResult = dyn_cast<OpResult>(value)) {
716716
// The type of an OpResult must match the corresponding iter_arg type.
717717
BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
718-
auto bufferType =
719-
bufferization::getBufferType(bbArg, options, invocationStack);
720-
if (failed(bufferType))
721-
return failure();
722-
assert(isa<BaseMemRefType>(*bufferType) && "expected memref type");
723-
return cast<BaseMemRefType>(*bufferType);
718+
return bufferization::getBufferType(bbArg, options, invocationStack);
724719
}
725720

726721
// Compute result/argument number.
@@ -1079,8 +1074,8 @@ struct WhileOpInterface
10791074
// scf.condition was already bufferized.
10801075
return cast<BufferLikeType>(conditionYieldedVal.getType());
10811076
}
1082-
return bufferization::detail::asMemRefType(bufferization::getBufferType(
1083-
conditionYieldedVal, options, invocationStack));
1077+
return bufferization::getBufferType(conditionYieldedVal, options,
1078+
invocationStack);
10841079
}
10851080

10861081
/// Assert that yielded values of an scf.while op are equivalent to their
@@ -1308,14 +1303,14 @@ struct ForallOpInterface
13081303
if (auto bbArg = dyn_cast<BlockArgument>(value))
13091304
// A tensor block argument has the same bufferized type as the
13101305
// corresponding output operand.
1311-
return bufferization::detail::asMemRefType(bufferization::getBufferType(
1312-
forallOp.getTiedOpOperand(bbArg)->get(), options, invocationStack));
1306+
return bufferization::getBufferType(
1307+
forallOp.getTiedOpOperand(bbArg)->get(), options, invocationStack);
13131308

13141309
// The bufferized result type is the same as the bufferized type of the
13151310
// corresponding output operand.
1316-
return bufferization::detail::asMemRefType(bufferization::getBufferType(
1311+
return bufferization::getBufferType(
13171312
forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()], options,
1318-
invocationStack));
1313+
invocationStack);
13191314
}
13201315

13211316
bool isRepetitiveRegion(Operation *op, unsigned index) const {

mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ struct CastOpInterface
8282
auto rankedResultType = cast<RankedTensorType>(castOp.getType());
8383
return cast<BufferLikeType>(MemRefType::get(
8484
rankedResultType.getShape(), rankedResultType.getElementType(),
85-
llvm::cast<MemRefType>(*maybeSrcBufferType).getLayout(), memorySpace);
85+
llvm::cast<MemRefType>(*maybeSrcBufferType).getLayout(), memorySpace));
8686
}
8787

8888
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -503,8 +503,8 @@ struct FromElementsOpInterface
503503
/*copy=*/false);
504504
if (failed(tensorAlloc))
505505
return failure();
506-
FailureOr<BaseMemRefType> memrefType = bufferization::detail::asMemRefType(
507-
bufferization::getBufferType(*tensorAlloc, options));
506+
FailureOr<BufferLikeType> memrefType =
507+
bufferization::getBufferType(*tensorAlloc, options);
508508
if (failed(memrefType))
509509
return failure();
510510
Value buffer = rewriter.create<bufferization::ToBufferOp>(

mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -272,10 +272,10 @@ func.func @materialize_in_dest_raw(%f: f32, %f2: f32, %idx: index) -> (tensor<5x
272272

273273
// -----
274274

275-
// CHECK-LABEL: func.func @test_dialect_op(
275+
// CHECK: func.func @custom_op(
276276
// CHECK-SAME: %[[ARG:.*]]: !test.test_tensor<[32, 64], f64>
277277
// CHECK-SAME: ) -> !test.test_tensor<[32, 128], f64> {
278-
func.func @test_dialect_op(%arg: !test.test_tensor<[32, 64], f64>)
278+
func.func @custom_op(%arg: !test.test_tensor<[32, 64], f64>)
279279
-> !test.test_tensor<[32, 128], f64> {
280280
// CHECK: %[[MEMREF:.*]] = bufferization.to_buffer %[[ARG]]
281281
// CHECK: %[[DUMMY:.*]] = "test.dummy_memref_op"(%[[MEMREF]])
@@ -288,3 +288,22 @@ func.func @test_dialect_op(%arg: !test.test_tensor<[32, 64], f64>)
288288
// CHECK: return %[[OUT]]
289289
return %out : !test.test_tensor<[32, 128], f64>
290290
}
291+
292+
// -----
293+
294+
// CHECK: func.func @custom_origin_op()
295+
// CHECK-SAME: -> !test.test_tensor<[42], f64> {
296+
func.func @custom_origin_op() -> !test.test_tensor<[42], f64> {
297+
// CHECK: %[[MEMREF:.*]] = "test.create_memref_op"() : ()
298+
// CHECK-SAME: -> !test.test_memref<[21], f64>
299+
// CHECK: %[[DUMMY:.*]] = "test.dummy_memref_op"(%[[MEMREF]])
300+
// CHECK-SAME: : (!test.test_memref<[21], f64>)
301+
// CHECK-SAME: -> !test.test_memref<[42], f64>
302+
%in = "test.create_tensor_op"() : () -> !test.test_tensor<[21], f64>
303+
%out = "test.dummy_tensor_op"(%in) : (!test.test_tensor<[21], f64>)
304+
-> !test.test_tensor<[42], f64>
305+
306+
// CHECK: %[[OUT:.*]] = bufferization.to_tensor %[[DUMMY]]
307+
// CHECK: return %[[OUT]]
308+
return %out : !test.test_tensor<[42], f64>
309+
}

mlir/test/lib/Dialect/Test/TestOpDefs.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1418,3 +1418,35 @@ ::mlir::LogicalResult test::TestDummyTensorOp::bufferize(
14181418

14191419
return mlir::success();
14201420
}
1421+
1422+
::mlir::LogicalResult test::TestCreateTensorOp::bufferize(
1423+
::mlir::RewriterBase &rewriter,
1424+
const ::mlir::bufferization::BufferizationOptions &options) {
1425+
// Note: mlir::bufferization::getBufferType() would internally call
1426+
// TestCreateTensorOp::getBufferType()
1427+
const auto bufferizedOutType =
1428+
mlir::bufferization::getBufferType(getOutput(), options);
1429+
if (mlir::failed(bufferizedOutType))
1430+
return failure();
1431+
1432+
// replace op with memref analogy
1433+
auto createMemrefOp =
1434+
rewriter.create<test::TestCreateMemrefOp>(getLoc(), *bufferizedOutType);
1435+
1436+
mlir::bufferization::replaceOpWithBufferizedValues(
1437+
rewriter, getOperation(), createMemrefOp.getResult());
1438+
1439+
return mlir::success();
1440+
}
1441+
1442+
mlir::FailureOr<mlir::bufferization::BufferLikeType>
1443+
test::TestCreateTensorOp::getBufferType(
1444+
mlir::Value value, const mlir::bufferization::BufferizationOptions &,
1445+
llvm::SmallVector<::mlir::Value> &) {
1446+
const auto type = dyn_cast<test::TestTensorType>(value.getType());
1447+
if (type == nullptr)
1448+
return failure();
1449+
1450+
return cast<mlir::bufferization::BufferLikeType>(test::TestMemrefType::get(
1451+
getContext(), type.getShape(), type.getElementType(), nullptr));
1452+
}

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 37 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3603,28 +3603,16 @@ def TestMultiSlotAlloca : TEST_Op<"multi_slot_alloca",
36033603
// Test Ops bufferization
36043604
//===----------------------------------------------------------------------===//
36053605

3606-
def TestDummyTensorOp : TEST_Op<"dummy_tensor_op", [BufferizableOpInterface]> {
3606+
def TestDummyTensorOp : TEST_Op<"dummy_tensor_op",
3607+
[DeclareOpInterfaceMethods<BufferizableOpInterface,
3608+
["bufferize", "bufferizesToMemoryRead",
3609+
"bufferizesToMemoryWrite", "getAliasingValues"]>]> {
36073610
let arguments = (ins
36083611
Arg<TestTensorType>:$input
36093612
);
36103613
let results = (outs
36113614
Arg<TestTensorType>:$output
36123615
);
3613-
let extraClassDeclaration = [{
3614-
// BufferizableOpInterface
3615-
bool bufferizesToMemoryRead(mlir::OpOperand&,
3616-
const mlir::bufferization::AnalysisState&);
3617-
3618-
bool bufferizesToMemoryWrite(mlir::OpOperand&,
3619-
const mlir::bufferization::AnalysisState&);
3620-
3621-
mlir::bufferization::AliasingValueList getAliasingValues(mlir::OpOperand&,
3622-
const mlir::bufferization::AnalysisState&);
3623-
3624-
mlir::LogicalResult bufferize(
3625-
mlir::RewriterBase& rewriter,
3626-
const mlir::bufferization::BufferizationOptions& options);
3627-
}];
36283616

36293617
let extraClassDefinition = [{
36303618
bool test::TestDummyTensorOp::bufferizesToMemoryRead(::mlir::OpOperand&,
@@ -3652,4 +3640,37 @@ def TestDummyMemrefOp : TEST_Op<"dummy_memref_op", []> {
36523640
);
36533641
}
36543642

3643+
def TestCreateTensorOp : TEST_Op<"create_tensor_op",
3644+
[DeclareOpInterfaceMethods<BufferizableOpInterface,
3645+
["bufferize", "getBufferType", "bufferizesToMemoryRead",
3646+
"bufferizesToMemoryWrite", "getAliasingValues",
3647+
"bufferizesToAllocation"]>]> {
3648+
let arguments = (ins);
3649+
let results = (outs Arg<TestTensorType>:$output);
3650+
let extraClassDefinition = [{
3651+
bool test::TestCreateTensorOp::bufferizesToMemoryRead(::mlir::OpOperand&,
3652+
const ::mlir::bufferization::AnalysisState&) {
3653+
return true;
3654+
}
3655+
bool test::TestCreateTensorOp::bufferizesToMemoryWrite(::mlir::OpOperand&,
3656+
const ::mlir::bufferization::AnalysisState&) {
3657+
return true;
3658+
}
3659+
bool test::TestCreateTensorOp::bufferizesToAllocation(mlir::Value value) {
3660+
return false;
3661+
}
3662+
3663+
::mlir::bufferization::AliasingValueList
3664+
test::TestCreateTensorOp::getAliasingValues(::mlir::OpOperand&,
3665+
const ::mlir::bufferization::AnalysisState&) {
3666+
return {};
3667+
}
3668+
}];
3669+
}
3670+
3671+
def TestCreateMemrefOp : TEST_Op<"create_memref_op"> {
3672+
let arguments = (ins);
3673+
let results = (outs Arg<TestMemrefType>:$output);
3674+
}
3675+
36553676
#endif // TEST_OPS

0 commit comments

Comments
 (0)