1212include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td"
1313include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td"
1414include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
15- include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td"
1615include "mlir/Dialect/Bufferization/IR/BufferizationBase.td"
1716include "mlir/Interfaces/DestinationStyleOpInterface.td"
1817include "mlir/Interfaces/InferTypeOpInterface.td"
@@ -387,31 +386,20 @@ def Bufferization_DeallocTensorOp : Bufferization_Op<"dealloc_tensor",
387386// ToTensorOp
388387//===----------------------------------------------------------------------===//
389388
390- class Bufferization_TensorAndBufferMatch<string tensor, string buffer> : PredOpTrait<
391- "specified tensor and buffer types match",
392- CPred<
393- "::mlir::bufferization::detail::typesMatchAfterBufferization("
394- "$_op, $" # tensor # ", $" # buffer #")"
395- >
396- >;
397-
398389def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
399390 BufferizableOpInterface,
400391 SameOperandsAndResultShape,
401392 SameOperandsAndResultElementType,
402- Bufferization_TensorAndBufferMatch<"result ", "buffer" >
393+ AllElementTypesMatch<["memref ", "result"] >
403394 ]> {
404- let summary = "create a buffer-like type from a tensor-like type ";
395+ let summary = "create a tensor from a `memref` ";
405396 let description = [{
406- An operation that creates a tensor from a buffer. The result value is a
407- tensor-like type that must match the corresponding buffer-like operand as
408- per TensorLikeType::verifyCompatibleBufferType(). For builtins (TensorType
409- and BaseMemRefType), this means that shapes and element types match between
410- the tensor and the buffer.
397+ An operation that creates a tensor from a `memref`. The result value is a
398+ tensor whose shape and element type match the memref operand.
411399
412400 The opposite of this op is `to_buffer`. Together, these two ops are
413401 useful for source/target materializations when doing type conversions
414- involving tensors and buffers .
402+ involving tensors and memrefs .
415403
416404 Example:
417405
@@ -453,16 +441,19 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
453441 away. However, such IR is no longer bufferizable with One-Shot Bufferize.
454442 }];
455443
456- let arguments = (ins Arg<Bufferization_BufferLikeTypeInterface ,
444+ let arguments = (ins Arg<AnyRankedOrUnrankedMemRef ,
457445 "the reference to load from",
458- [MemReadAt<0, FullEffect>]>:$buffer ,
446+ [MemReadAt<0, FullEffect>]>:$memref ,
459447 UnitAttr:$restrict, UnitAttr:$writable);
460- let results = (outs Bufferization_TensorLikeTypeInterface :$result);
448+ let results = (outs AnyTensor :$result);
461449
462450 let extraClassDeclaration = [{
463451 /// The result of a to_tensor is always a tensor.
464- ::mlir::bufferization::TensorLikeType getType() {
465- return getResult().getType();
452+ TensorType getType() {
453+ Type resultType = getResult().getType();
454+ if (::llvm::isa<TensorType>(resultType))
455+ return ::llvm::cast<TensorType>(resultType);
456+ return {};
466457 }
467458
468459 //===------------------------------------------------------------------===//
@@ -481,15 +472,22 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
481472 FailureOr<BaseMemRefType> getBufferType(
482473 Value value, const BufferizationOptions &options,
483474 const BufferizationState &state, SmallVector<Value> &invocationStack) {
484- return ::llvm::cast<BaseMemRefType>(getBuffer ().getType());
475+ return ::llvm::cast<BaseMemRefType>(getMemref ().getType());
485476 }
486477 }];
487478
488479 let assemblyFormat = [{
489- $buffer (`restrict` $restrict^)? (`writable` $writable^)? attr-dict
490- `:` type($buffer ) `to` type($result)
480+ $memref (`restrict` $restrict^)? (`writable` $writable^)? attr-dict
481+ `:` type($memref ) `to` type($result)
491482 }];
492483
484+ let builders = [
485+ OpBuilder<(ins "Value":$memref, CArg<"bool", "false">:$restrict, CArg<"bool", "false">:$writeable), [{
486+ auto rtt = memref::getTensorTypeFromMemRefType(memref.getType());
487+ build($_builder, $_state, rtt, memref, restrict, writeable);
488+ }]>
489+ ];
490+
493491 let hasCanonicalizer = 1;
494492 let hasFolder = 1;
495493}
@@ -504,9 +502,10 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
504502 SameOperandsAndResultShape,
505503 SameOperandsAndResultElementType,
506504 Pure,
507- Bufferization_TensorAndBufferMatch<"tensor", "buffer">
505+ AllShapesMatch<["memref", "tensor"]>,
506+ AllElementTypesMatch<["memref", "tensor"]>
508507 ]> {
509- let summary = "cast a tensor-like type to buffer-like type ";
508+ let summary = "cast a tensor to memref ";
510509 let description = [{
511510 An operation that returns the future buffer of a `tensor`.
512511
@@ -524,8 +523,8 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
524523 the returned buffer) will not be written to.
525524 }];
526525
527- let arguments = (ins Bufferization_TensorLikeTypeInterface :$tensor, UnitAttr:$read_only);
528- let results = (outs Bufferization_BufferLikeTypeInterface:$buffer );
526+ let arguments = (ins AnyTensor :$tensor, UnitAttr:$read_only);
527+ let results = (outs AnyRankedOrUnrankedMemRef:$memref );
529528
530529 let extraClassDeclaration = [{
531530 //===------------------------------------------------------------------===//
@@ -560,7 +559,7 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
560559 }];
561560
562561 let assemblyFormat = [{
563- $tensor (`read_only` $read_only^)? attr-dict `:` type($tensor) `to` type($buffer )
562+ $tensor (`read_only` $read_only^)? attr-dict `:` type($tensor) `to` type($memref )
564563 }];
565564
566565 let hasFolder = 1;
0 commit comments