1212include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td"
1313include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td"
1414include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
15+ include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td"
1516include "mlir/Dialect/Bufferization/IR/BufferizationBase.td"
1617include "mlir/Interfaces/DestinationStyleOpInterface.td"
1718include "mlir/Interfaces/InferTypeOpInterface.td"
@@ -383,20 +384,31 @@ def Bufferization_DeallocTensorOp : Bufferization_Op<"dealloc_tensor",
383384// ToTensorOp
384385//===----------------------------------------------------------------------===//
385386
387+ class Bufferization_TensorAndBufferMatch<string tensor, string buffer> : PredOpTrait<
388+ "specified tensor and buffer types match",
389+ CPred<
390+ "::mlir::bufferization::detail::typesMatchAfterBufferization("
391+ "$_op, $" # tensor # ", $" # buffer #")"
392+ >
393+ >;
394+
386395def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
387396 BufferizableOpInterface,
388397 SameOperandsAndResultShape,
389398 SameOperandsAndResultElementType,
390- AllElementTypesMatch<["memref ", "result"] >
399+ Bufferization_TensorAndBufferMatch<"result ", "buffer" >
391400 ]> {
392- let summary = "create a tensor from a `memref` ";
401+ let summary = "create a buffer-like type from a tensor-like type ";
393402 let description = [{
394- An operation that creates a tensor from a `memref`. The result value is a
395- tensor whose shape and element type match the memref operand.
403+ An operation that creates a tensor from a buffer. The result value is a
404+ tensor-like type that must match the corresponding buffer-like operand as
405+ per TensorLikeType::verifyCompatibleBufferType(). For builtins (TensorType
406+ and BaseMemRefType), this means that shapes and element types match between
407+ the tensor and the buffer.
396408
397409 The opposite of this op is `to_buffer`. Together, these two ops are
398410 useful for source/target materializations when doing type conversions
399- involving tensors and memrefs .
411+ involving tensors and buffers .
400412
401413 Example:
402414
@@ -438,19 +450,16 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
438450 away. However, such IR is no longer bufferizable with One-Shot Bufferize.
439451 }];
440452
441- let arguments = (ins Arg<AnyRankedOrUnrankedMemRef ,
453+ let arguments = (ins Arg<Bufferization_BufferLikeTypeInterface ,
442454 "the reference to load from",
443- [MemReadAt<0, FullEffect>]>:$memref ,
455+ [MemReadAt<0, FullEffect>]>:$buffer ,
444456 UnitAttr:$restrict, UnitAttr:$writable);
445- let results = (outs AnyTensor :$result);
457+ let results = (outs Bufferization_TensorLikeTypeInterface :$result);
446458
447459 let extraClassDeclaration = [{
448460 /// The result of a to_tensor is always a tensor.
449- TensorType getType() {
450- Type resultType = getResult().getType();
451- if (::llvm::isa<TensorType>(resultType))
452- return ::llvm::cast<TensorType>(resultType);
453- return {};
461+ ::mlir::bufferization::TensorLikeType getType() {
462+ return getResult().getType();
454463 }
455464
456465 //===------------------------------------------------------------------===//
@@ -468,22 +477,15 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
468477 FailureOr<BaseMemRefType> getBufferType(
469478 Value value, const BufferizationOptions &options,
470479 SmallVector<Value> &invocationStack) {
471- return ::llvm::cast<BaseMemRefType>(getMemref ().getType());
480+ return ::llvm::cast<BaseMemRefType>(getBuffer ().getType());
472481 }
473482 }];
474483
475484 let assemblyFormat = [{
476- $memref (`restrict` $restrict^)? (`writable` $writable^)? attr-dict
477- `:` type($memref ) `to` type($result)
485+ $buffer (`restrict` $restrict^)? (`writable` $writable^)? attr-dict
486+ `:` type($buffer ) `to` type($result)
478487 }];
479488
480- let builders = [
481- OpBuilder<(ins "Value":$memref, CArg<"bool", "false">:$restrict, CArg<"bool", "false">:$writeable), [{
482- auto rtt = memref::getTensorTypeFromMemRefType(memref.getType());
483- build($_builder, $_state, rtt, memref, restrict, writeable);
484- }]>
485- ];
486-
487489 let hasCanonicalizer = 1;
488490 let hasFolder = 1;
489491}
@@ -498,10 +500,9 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
498500 SameOperandsAndResultShape,
499501 SameOperandsAndResultElementType,
500502 Pure,
501- AllShapesMatch<["memref", "tensor"]>,
502- AllElementTypesMatch<["memref", "tensor"]>
503+ Bufferization_TensorAndBufferMatch<"tensor", "buffer">
503504 ]> {
504- let summary = "cast a tensor to memref ";
505+ let summary = "cast a tensor-like type to buffer-like type ";
505506 let description = [{
506507 An operation that returns the future buffer of a `tensor`.
507508
@@ -519,8 +520,8 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
519520 the returned buffer) will not be written to.
520521 }];
521522
522- let arguments = (ins AnyTensor :$tensor, UnitAttr:$read_only);
523- let results = (outs AnyRankedOrUnrankedMemRef:$memref );
523+ let arguments = (ins Bufferization_TensorLikeTypeInterface :$tensor, UnitAttr:$read_only);
524+ let results = (outs Bufferization_BufferLikeTypeInterface:$buffer );
524525
525526 let extraClassDeclaration = [{
526527 //===------------------------------------------------------------------===//
@@ -554,7 +555,7 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
554555 }];
555556
556557 let assemblyFormat = [{
557- $tensor (`read_only` $read_only^)? attr-dict `:` type($tensor) `to` type($memref )
558+ $tensor (`read_only` $read_only^)? attr-dict `:` type($tensor) `to` type($buffer )
558559 }];
559560
560561 let hasFolder = 1;
0 commit comments