1212include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td"
1313include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td"
1414include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
15+ include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td"
1516include "mlir/Dialect/Bufferization/IR/BufferizationBase.td"
1617include "mlir/Interfaces/DestinationStyleOpInterface.td"
1718include "mlir/Interfaces/InferTypeOpInterface.td"
@@ -386,20 +387,31 @@ def Bufferization_DeallocTensorOp : Bufferization_Op<"dealloc_tensor",
386387// ToTensorOp
387388//===----------------------------------------------------------------------===//
388389
390+ class Bufferization_TensorAndBufferMatch<string tensor, string buffer> : PredOpTrait<
391+ "specified tensor and buffer types match",
392+ CPred<
393+ "::mlir::bufferization::detail::typesMatchAfterBufferization("
394+ "$_op, $" # tensor # ", $" # buffer #")"
395+ >
396+ >;
397+
389398def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
390399 BufferizableOpInterface,
391400 SameOperandsAndResultShape,
392401 SameOperandsAndResultElementType,
393- AllElementTypesMatch<["memref ", "result"] >
402+ Bufferization_TensorAndBufferMatch<"result ", "buffer" >
394403 ]> {
395- let summary = "create a tensor from a `memref` ";
404+ let summary = "create a buffer-like type from a tensor-like type ";
396405 let description = [{
397- An operation that creates a tensor from a `memref`. The result value is a
398- tensor whose shape and element type match the memref operand.
406+ An operation that creates a tensor from a buffer. The result value is a
407+ tensor-like type that must match the corresponding buffer-like operand as
408+ per TensorLikeType::verifyCompatibleBufferType(). For builtins (TensorType
409+ and BaseMemRefType), this means that shapes and element types match between
410+ the tensor and the buffer.
399411
400412 The opposite of this op is `to_buffer`. Together, these two ops are
401413 useful for source/target materializations when doing type conversions
402- involving tensors and memrefs .
414+ involving tensors and buffers .
403415
404416 Example:
405417
@@ -441,19 +453,16 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
441453 away. However, such IR is no longer bufferizable with One-Shot Bufferize.
442454 }];
443455
444- let arguments = (ins Arg<AnyRankedOrUnrankedMemRef ,
456+ let arguments = (ins Arg<Bufferization_BufferLikeTypeInterface ,
445457 "the reference to load from",
446- [MemReadAt<0, FullEffect>]>:$memref ,
458+ [MemReadAt<0, FullEffect>]>:$buffer ,
447459 UnitAttr:$restrict, UnitAttr:$writable);
448- let results = (outs AnyTensor :$result);
460+ let results = (outs Bufferization_TensorLikeTypeInterface :$result);
449461
450462 let extraClassDeclaration = [{
451463 /// The result of a to_tensor is always a tensor.
452- TensorType getType() {
453- Type resultType = getResult().getType();
454- if (::llvm::isa<TensorType>(resultType))
455- return ::llvm::cast<TensorType>(resultType);
456- return {};
464+ ::mlir::bufferization::TensorLikeType getType() {
465+ return getResult().getType();
457466 }
458467
459468 //===------------------------------------------------------------------===//
@@ -472,22 +481,15 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
472481 FailureOr<BaseMemRefType> getBufferType(
473482 Value value, const BufferizationOptions &options,
474483 const BufferizationState &state, SmallVector<Value> &invocationStack) {
475- return ::llvm::cast<BaseMemRefType>(getMemref ().getType());
484+ return ::llvm::cast<BaseMemRefType>(getBuffer ().getType());
476485 }
477486 }];
478487
479488 let assemblyFormat = [{
480- $memref (`restrict` $restrict^)? (`writable` $writable^)? attr-dict
481- `:` type($memref ) `to` type($result)
489+ $buffer (`restrict` $restrict^)? (`writable` $writable^)? attr-dict
490+ `:` type($buffer ) `to` type($result)
482491 }];
483492
484- let builders = [
485- OpBuilder<(ins "Value":$memref, CArg<"bool", "false">:$restrict, CArg<"bool", "false">:$writeable), [{
486- auto rtt = memref::getTensorTypeFromMemRefType(memref.getType());
487- build($_builder, $_state, rtt, memref, restrict, writeable);
488- }]>
489- ];
490-
491493 let hasCanonicalizer = 1;
492494 let hasFolder = 1;
493495}
@@ -502,10 +504,9 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
502504 SameOperandsAndResultShape,
503505 SameOperandsAndResultElementType,
504506 Pure,
505- AllShapesMatch<["memref", "tensor"]>,
506- AllElementTypesMatch<["memref", "tensor"]>
507+ Bufferization_TensorAndBufferMatch<"tensor", "buffer">
507508 ]> {
508- let summary = "cast a tensor to memref ";
509+ let summary = "cast a tensor-like type to buffer-like type ";
509510 let description = [{
510511 An operation that returns the future buffer of a `tensor`.
511512
@@ -523,8 +524,8 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
523524 the returned buffer) will not be written to.
524525 }];
525526
526- let arguments = (ins AnyTensor :$tensor, UnitAttr:$read_only);
527- let results = (outs AnyRankedOrUnrankedMemRef:$memref );
527+ let arguments = (ins Bufferization_TensorLikeTypeInterface :$tensor, UnitAttr:$read_only);
528+ let results = (outs Bufferization_BufferLikeTypeInterface:$buffer );
528529
529530 let extraClassDeclaration = [{
530531 //===------------------------------------------------------------------===//
@@ -559,7 +560,7 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
559560 }];
560561
561562 let assemblyFormat = [{
562- $tensor (`read_only` $read_only^)? attr-dict `:` type($tensor) `to` type($memref )
563+ $tensor (`read_only` $read_only^)? attr-dict `:` type($tensor) `to` type($buffer )
563564 }];
564565
565566 let hasFolder = 1;
0 commit comments