Skip to content

Commit 1508a8e

Browse files
authored
[MLIR][Linalg] Modify rewriteAsPaddedOp to not remove pre-padded op (llvm#163467)
Refactor/redesign `FailureOr<TilingInterface> rewriteAsPaddedOp(...)` to not remove unpadded operation. This is more in line with how other transformations like tiling work, where the user of the transformation decides when to replace the actual operation. Instead of this, return all info as a struct. --------- Signed-off-by: James Newling <[email protected]>
1 parent df89564 commit 1508a8e

File tree

3 files changed

+97
-96
lines changed

3 files changed

+97
-96
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
#include "mlir/Interfaces/TilingInterface.h"
2626
#include "mlir/Transforms/DialectConversion.h"
2727
#include "llvm/ADT/SmallBitVector.h"
28-
#include "llvm/ADT/SmallSet.h"
2928

3029
namespace mlir {
3130
namespace bufferization {
@@ -621,35 +620,43 @@ LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
621620
/// In the future, more general interfaces can be devised to encode similar
622621
/// shape evolutions and map between an op and its operands.
623622
SmallVector<OpFoldResult>
624-
computePaddedShape(RewriterBase &rewriter, TypedValue<RankedTensorType> v,
623+
computePaddedShape(OpBuilder &, TypedValue<RankedTensorType> v,
625624
AffineMap indexingMap, ArrayRef<OpFoldResult> indexingSizes,
626625
const PadTilingInterfaceOptions &options);
627626

628627
using PadSizeComputationFunction =
629628
std::function<FailureOr<SmallVector<OpFoldResult>>(
630-
RewriterBase &, OpOperand &, ArrayRef<Range>,
629+
OpBuilder &, OpOperand &, ArrayRef<Range>,
631630
const PadTilingInterfaceOptions &)>;
632631

633632
/// Specific helper for Linalg ops.
634-
FailureOr<SmallVector<OpFoldResult>> computeIndexingMapOpInterfacePaddedShape(
635-
RewriterBase &rewriter, OpOperand &operandToPad,
636-
ArrayRef<Range> iterationDomain, const PadTilingInterfaceOptions &options);
633+
FailureOr<SmallVector<OpFoldResult>>
634+
computeIndexingMapOpInterfacePaddedShape(OpBuilder &, OpOperand &operandToPad,
635+
ArrayRef<Range> iterationDomain,
636+
const PadTilingInterfaceOptions &);
637+
638+
/// Operations and values created in the process of padding a TilingInterface
639+
/// operation.
640+
struct PadTilingInterfaceResult {
641+
/// The operands of the padded op.
642+
SmallVector<tensor::PadOp> padOps;
643+
/// The padded op, a clone of `toPad` with padded operands.
644+
TilingInterface paddedOp;
645+
/// Slices of the padded op's results, same types as `toPad`.
646+
SmallVector<Value> replacements;
647+
};
637648

638-
/// Pad the iterator dimensions `options.paddingDimensions` of `opToPad`.
639-
///
649+
/// Pad the iterator dimensions of `toPad`.
640650
/// * "options.paddingSizes" indicates that each padding dimension should be
641651
/// padded to the specified padding size.
642652
/// * "options.padToMultipleOf" indicates that the paddingSizes should be
643653
// interpreted as the bounding box (dynamic) value to pad to.
644654
/// * Use "options.paddingValues" to set the padding value of the created
645655
// tensor::PadOp.
646-
/// * The tensor::PadOp is returned on success.
647-
648-
FailureOr<TilingInterface>
649-
rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad,
650-
const PadTilingInterfaceOptions &constOptions,
651-
SmallVector<tensor::PadOp> &padOps,
652-
const PadSizeComputationFunction &computePaddingSizeFun =
656+
FailureOr<PadTilingInterfaceResult>
657+
rewriteAsPaddedOp(OpBuilder &, TilingInterface toPad,
658+
PadTilingInterfaceOptions options,
659+
const PadSizeComputationFunction & =
653660
&computeIndexingMapOpInterfacePaddedShape);
654661

655662
namespace detail {

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2457,26 +2457,24 @@ transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
24572457
}
24582458

24592459
// Set options.
2460-
TilingInterface paddedOp;
24612460
PadTilingInterfaceOptions options;
24622461
options.setPaddingValues(paddingValues)
24632462
.setPaddingSizes(getMixedPaddingSizes())
24642463
.setPadToMultipleOf(getPadToMultipleOf());
24652464

2466-
// Apply padding.
2467-
SmallVector<tensor::PadOp> newPadOps;
2468-
FailureOr<TilingInterface> maybePaddedOp = rewriteAsPaddedOp(
2469-
rewriter, cast<TilingInterface>(targetOp.getOperation()), options,
2470-
newPadOps);
2471-
if (failed(maybePaddedOp)) {
2465+
auto maybePadOps = rewriteAsPaddedOp(
2466+
rewriter, cast<TilingInterface>(targetOp.getOperation()), options);
2467+
if (failed(maybePadOps)) {
24722468
auto diag = emitSilenceableError() << "failed to pad op";
24732469
diag.attachNote(target->getLoc()) << "target op";
24742470
return diag;
24752471
}
2472+
const auto &[paddedOperands, paddedOp, slicedResults] = maybePadOps.value();
24762473

24772474
// Set transform results.
2478-
paddedOps.push_back(cast<TilingInterface>(maybePaddedOp->getOperation()));
2479-
padOps.append(newPadOps.begin(), newPadOps.end());
2475+
paddedOps.push_back(paddedOp);
2476+
padOps.append(paddedOperands.begin(), paddedOperands.end());
2477+
rewriter.replaceOp(targetOp.getOperation(), slicedResults);
24802478
}
24812479

24822480
results.set(cast<OpResult>(getPadded()), paddedOps);

0 commit comments

Comments
 (0)