|
25 | 25 | #include "mlir/Interfaces/TilingInterface.h" |
26 | 26 | #include "mlir/Transforms/DialectConversion.h" |
27 | 27 | #include "llvm/ADT/SmallBitVector.h" |
28 | | -#include "llvm/ADT/SmallSet.h" |
29 | 28 |
|
30 | 29 | namespace mlir { |
31 | 30 | namespace bufferization { |
@@ -621,35 +620,43 @@ LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad, |
621 | 620 | /// In the future, more general interfaces can be devised to encode similar |
622 | 621 | /// shape evolutions and map between an op and its operands. |
623 | 622 | SmallVector<OpFoldResult> |
624 | | -computePaddedShape(RewriterBase &rewriter, TypedValue<RankedTensorType> v, |
| 623 | +computePaddedShape(OpBuilder &, TypedValue<RankedTensorType> v, |
625 | 624 | AffineMap indexingMap, ArrayRef<OpFoldResult> indexingSizes, |
626 | 625 | const PadTilingInterfaceOptions &options); |
627 | 626 |
|
628 | 627 | using PadSizeComputationFunction = |
629 | 628 | std::function<FailureOr<SmallVector<OpFoldResult>>( |
630 | | - RewriterBase &, OpOperand &, ArrayRef<Range>, |
| 629 | + OpBuilder &, OpOperand &, ArrayRef<Range>, |
631 | 630 | const PadTilingInterfaceOptions &)>; |
632 | 631 |
|
633 | 632 | /// Specific helper for Linalg ops. |
634 | | -FailureOr<SmallVector<OpFoldResult>> computeIndexingMapOpInterfacePaddedShape( |
635 | | - RewriterBase &rewriter, OpOperand &operandToPad, |
636 | | - ArrayRef<Range> iterationDomain, const PadTilingInterfaceOptions &options); |
| 633 | +FailureOr<SmallVector<OpFoldResult>> |
| 634 | +computeIndexingMapOpInterfacePaddedShape(OpBuilder &, OpOperand &operandToPad, |
| 635 | + ArrayRef<Range> iterationDomain, |
| 636 | + const PadTilingInterfaceOptions &); |
| 637 | + |
| 638 | +/// Operations and values created in the process of padding a TilingInterface |
| 639 | +/// operation. |
| 640 | +struct PadTilingInterfaceResult { |
| 641 | + /// The operands of the padded op. |
| 642 | + SmallVector<tensor::PadOp> padOps; |
| 643 | + /// The padded op, a clone of `toPad` with padded operands. |
| 644 | + TilingInterface paddedOp; |
| 645 | + /// Slices of the padded op's results, same types as `toPad`. |
| 646 | + SmallVector<Value> replacements; |
| 647 | +}; |
637 | 648 |
|
638 | | -/// Pad the iterator dimensions `options.paddingDimensions` of `opToPad`. |
639 | | -/// |
| 649 | +/// Pad the iterator dimensions of `toPad`. |
640 | 650 | /// * "options.paddingSizes" indicates that each padding dimension should be |
641 | 651 | /// padded to the specified padding size. |
642 | 652 | /// * "options.padToMultipleOf" indicates that the paddingSizes should be |
643 | 653 | // interpreted as the bounding box (dynamic) value to pad to. |
644 | 654 | /// * Use "options.paddingValues" to set the padding value of the created |
645 | 655 | // tensor::PadOp. |
646 | | -/// * The tensor::PadOp is returned on success. |
647 | | - |
648 | | -FailureOr<TilingInterface> |
649 | | -rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad, |
650 | | - const PadTilingInterfaceOptions &constOptions, |
651 | | - SmallVector<tensor::PadOp> &padOps, |
652 | | - const PadSizeComputationFunction &computePaddingSizeFun = |
| 656 | +FailureOr<PadTilingInterfaceResult> |
| 657 | +rewriteAsPaddedOp(OpBuilder &, TilingInterface toPad, |
| 658 | + PadTilingInterfaceOptions options, |
| 659 | + const PadSizeComputationFunction & = |
653 | 660 | &computeIndexingMapOpInterfacePaddedShape); |
654 | 661 |
|
655 | 662 | namespace detail { |
|
0 commit comments