From c52d046a2cf004f81b8189d33665ec48082fe18f Mon Sep 17 00:00:00 2001
From: Fabian Mora
Date: Wed, 18 Jun 2025 12:17:35 +0000
Subject: [PATCH 1/2] [mlir][linalg] Add sizeToPadTo option to
 linalg::LinalgPaddingOptions

---
 .../Dialect/Linalg/Transforms/Transforms.h    |  17 ++
 .../include/mlir/Dialect/Linalg/Utils/Utils.h |  17 +-
 .../lib/Dialect/Linalg/Transforms/Padding.cpp | 163 ++++++++++++++----
 mlir/lib/Dialect/Linalg/Utils/Utils.cpp       |  24 ++-
 .../test/Dialect/Linalg/transform-op-pad.mlir |   4 +-
 5 files changed, 172 insertions(+), 53 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 2eef0a06d0eb4..147a2907f52e4 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -295,6 +295,23 @@ struct LinalgPaddingOptions {
     padToMultipleOf.emplace(m.begin(), m.end());
     return *this;
   }
+  /// A mapping between an operand and shape dim, and a size for a padding
+  /// dimension. Each size is expected to be greater than or equal to the
+  /// corresponding shape dim. If no value is provided then the constant upper
+  /// bound will be used.
+  DenseMap<std::pair<unsigned, unsigned>, OpFoldResult> sizeToPadTo;
+  LinalgPaddingOptions &setSizeToPadTo(unsigned operandIndex, unsigned dimIndex,
+                                       OpFoldResult size) {
+    assert(size && "expected non-null size");
+    sizeToPadTo[{operandIndex, dimIndex}] = size;
+    return *this;
+  }
+  /// Given the operand index and shape dim, returns the size to pad to.
+  OpFoldResult getSizeToPadTo(unsigned operandIndex, unsigned dimIndex) const {
+    return sizeToPadTo.lookup_or(
+        std::pair(operandIndex, dimIndex), nullptr);
+  }
+
   /// A flag for every operand to mark the PadOp as nofold which enables
   /// packing for statically shaped operands.
   SmallVector<bool> nofoldFlags;
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 80aa034d2199d..fc151d02ceef6 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -71,12 +71,14 @@ bool isParallelIterator(utils::IteratorType iteratorType);
 /// Check if iterator type has "reduction" semantics.
 bool isReductionIterator(utils::IteratorType iteratorType);
 
-/// Create a tensor::PadOp that pads `source` to the size of the statically
-/// sized `type` whose static sizes are assumed to be greater than the dynamic
-/// `source` size. The padding introduces trailing `pad` values until the
-/// target size is met. If `source` is defined by one or more LinalgOps that
-/// have been padded with the same value and sizes, return their padded result
-/// instead of creating a tensor::PadOp.
+/// Create a tensor::PadOp that pads `source` to the shape of `type` whose sizes
+/// are assumed to be greater than the dynamic `source` size. If `typeDynDims`
+/// is specified, then it must contain the sizes of all the dynamic dimensions
+/// in order of appearance in `type`, otherwise the function will pad those
+/// values to `0`. The padding introduces trailing `pad` values until the target
+/// size is met. If `source` is defined by one or more LinalgOps that have been
+/// padded with the same value and sizes, return their padded result instead of
+/// creating a tensor::PadOp.
 ///
 /// Example:
 /// ```
@@ -91,7 +93,8 @@ bool isReductionIterator(utils::IteratorType iteratorType);
 /// %4 = tensor.pad %3 low[0, 0] high[...] { tensor.yield %other_cst }
 /// ```
 Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
-                            Value source, Value pad, bool nofold);
+                            Value source, Value padding, bool nofold,
+                            ValueRange typeDynDims = std::nullopt);
 
 /// Returns GenericOp that copies an n-D memref. Unlike the current
 /// implementation of memref::CopyOp, this op can further tile, lower to loops
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
index 9a685f6dc96ac..dc9e11eccac4d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -22,53 +23,93 @@ using namespace mlir::linalg;
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
 #define DBGSNL() (llvm::dbgs() << "\n")
 
-/// Compute the padded shape of the given operand. The operand is padded to a
-/// static bounding box according to the specified padding options.
-static LogicalResult computePaddedShape(linalg::LinalgOp opToPad,
+namespace {
+/// Helper class for storing padding information.
+struct PaddingInfo {
+  PaddingInfo(int64_t padToMultipleOf = 1, OpFoldResult size = {})
+      : padToMultipleOf(padToMultipleOf), size(size) {}
+  /// Pad the tensor to a multiple of.
+  int64_t padToMultipleOf = 1;
+  /// The size used for padding.
+  OpFoldResult size = {};
+};
+
+/// Helper class for storing and computing the padded shape.
+struct PaddedShape {
+  /// Initializes the shape information and on success it returns whether the
+  /// shape of the operand will change. Returns failure if the operand cannot be
+  /// padded.
+  FailureOr<bool> initialize(linalg::LinalgOp opToPad, OpOperand *opOperand,
+                             const LinalgPaddingOptions &options);
+
+  /// Computes the padded shape.
+  void computePadding(OpBuilder &builder, Value operand);
+
+  /// Returns the new tensor type.
+  RankedTensorType getType(Type elemTy) {
+    return RankedTensorType::get(shape, elemTy);
+  }
+
+  SmallVector<Value> dynDims;
+
+private:
+  SmallVector<int64_t> shape;
+  DenseMap<int64_t, PaddingInfo> dimToInfo;
+};
+} // namespace
+
+FailureOr<bool> PaddedShape::initialize(linalg::LinalgOp opToPad,
                                         OpOperand *opOperand,
-                                        const LinalgPaddingOptions &options,
-                                        SmallVector<int64_t> &paddedShape,
-                                        bool &alreadyHasRequestedShape) {
+                                        const LinalgPaddingOptions &options) {
   AffineMap indexingMap = opToPad.getMatchingIndexingMap(opOperand);
-  ArrayRef<int64_t> shape = opToPad.getShape(opOperand);
+
+  // Initialize the padded shape.
+  llvm::append_range(shape, opToPad.getShape(opOperand));
 
   // Collect the shape dimensions that are a function of "paddingDimensions",
   // along with the multiple that they should be padded to ("1" if none).
-  alreadyHasRequestedShape = true;
-  DenseMap<int64_t, int64_t> shapeDimToMultiple;
+  bool alreadyHasRequestedShape = true;
   for (const auto &dimEn : enumerate(options.paddingDimensions)) {
     for (const auto &en : enumerate(indexingMap.getResults())) {
       if (en.value().isFunctionOfDim(dimEn.value())) {
+        PaddingInfo paddingInfo;
         int64_t dimSize = shape[en.index()];
         if (options.padToMultipleOf.has_value()) {
-          shapeDimToMultiple[en.index()] =
+          paddingInfo.padToMultipleOf =
               (*options.padToMultipleOf)[dimEn.index()];
         } else {
-          shapeDimToMultiple[en.index()] = 1;
+          paddingInfo.padToMultipleOf = 1;
         }
-        if (ShapedType::isDynamic(dimSize)) {
-          alreadyHasRequestedShape = false;
-        } else if (dimSize % shapeDimToMultiple[en.index()] != 0) {
+
+        // Check if the user provided a size in the options.
+        paddingInfo.size =
+            options.getSizeToPadTo(opOperand->getOperandNumber(), en.index());
+
+        // Set the padding info.
+        dimToInfo[en.index()] = paddingInfo;
+        if (ShapedType::isDynamic(dimSize) ||
+            dimSize % paddingInfo.padToMultipleOf != 0 ||
+            !paddingInfo.size.isNull()) {
           alreadyHasRequestedShape = false;
         }
       }
     }
   }
 
-  // Helper function to round a number up to a given multiple.
-  auto ceil = [](int64_t val, int64_t multiple) {
-    return ((val + multiple - 1) / multiple) * multiple;
-  };
-
   // Upper bound the sizes to obtain a static bounding box.
-  paddedShape.assign(shape.begin(), shape.end());
   for (int64_t i = 0, e = shape.size(); i < e; ++i) {
-    LLVM_DEBUG(DBGS() << "--compute padded size for dim " << i << "\n");
+    LLVM_DEBUG(DBGS() << "--computing un-padded size for dim " << i << "\n");
     // Skip dimensions that do not require padding.
-    if (!shapeDimToMultiple.contains(i)) {
+    if (!dimToInfo.contains(i)) {
       LLVM_DEBUG(DBGS() << "----dim does not require padding, SKIP\n");
       continue;
     }
+    PaddingInfo &info = dimToInfo[i];
+    if (info.size) {
+      LLVM_DEBUG(DBGS() << "----the user provided the size: " << info.size
+                        << "\n");
+      continue;
+    }
 
     // Otherwise, try to compute a constant upper bound for the size value.
     FailureOr<int64_t> upperBound =
         ValueBoundsConstraintSet::computeConstantBound(
@@ -77,14 +118,58 @@ static LogicalResult computePaddedShape(linalg::LinalgOp opToPad,
             /*dim=*/i},
         /*stopCondition=*/nullptr, /*closedUB=*/true);
     if (failed(upperBound)) {
-      LLVM_DEBUG(DBGS() << "----could not compute a bounding box for padding");
+      LLVM_DEBUG(
+          DBGS() << "----could not compute a bounding box for padding\n");
       return failure();
     }
-    paddedShape[i] = ceil(*upperBound, shapeDimToMultiple[i]);
-    LLVM_DEBUG(DBGS() << "----new dim size: " << paddedShape[i] << "\n");
+    info.size =
+        IntegerAttr::get(IndexType::get(opToPad.getContext()), *upperBound);
+    LLVM_DEBUG(DBGS() << "----new un-padded size: " << info.size << "\n");
   }
+  return alreadyHasRequestedShape;
+}
 
-  return success();
+void PaddedShape::computePadding(OpBuilder &builder, Value operand) {
+  Location loc = operand.getLoc();
+  AffineExpr sizeSym = builder.getAffineSymbolExpr(0);
+
+  // Compute the padding for each dimension.
+  for (auto &&[i, dim] : llvm::enumerate(shape)) {
+    LLVM_DEBUG(DBGS() << "--computing padded size for dim " << i << "\n");
+
+    // Get the padding info or default info for the shape dimension.
+    PaddingInfo paddingInfo = dimToInfo.lookup(i);
+
+    // Skip dimensions that do not require padding.
+    if (paddingInfo.size.isNull()) {
+      LLVM_DEBUG(DBGS() << "----dim does not require padding, SKIP\n");
+
+      // We still need to push the size as `makeComposedPadHighOp` expects a
+      // range with all the dynamic sizes, whether they're being padded or not.
+      if (ShapedType::isDynamic(dim)) {
+        dynDims.push_back(
+            cast<Value>(tensor::getMixedSize(builder, loc, operand, i)));
+      }
+      continue;
+    }
+
+    // Compute the padded size to be a multiple of `padToMultipleOf`.
+    AffineExpr szExpr = (sizeSym).ceilDiv(paddingInfo.padToMultipleOf) *
+                        paddingInfo.padToMultipleOf;
+    OpFoldResult paddedSize = affine::makeComposedFoldedAffineApply(
+        builder, loc, szExpr, paddingInfo.size);
+    assert(paddedSize && "invalid arguments to affine apply");
+
+    if (auto cstSzAttr = dyn_cast<Attribute>(paddedSize)) {
+      // Update the shape as the size is static.
+      dim = cast<IntegerAttr>(cstSzAttr).getValue().getZExtValue();
+    } else {
+      // Add a dynamic dimension.
+      dim = ShapedType::kDynamic;
+      dynDims.push_back(cast<Value>(paddedSize));
+    }
+    LLVM_DEBUG(DBGS() << "----new dim size: " << paddedSize << "\n");
+  }
 }
 
 /// Pad the `opOperand` in the "paddingDimensions" using the padding value and
@@ -107,20 +192,21 @@ static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
           options.padToMultipleOf->size() == options.paddingDimensions.size()) &&
          "invalid number of elements in padToMultipleOf");
 
-  // Compute padded shape.
-  SmallVector<int64_t> paddedShape;
-  bool alreadyHasRequestedShape = false;
-  if (failed(computePaddedShape(opToPad, opOperand, options, paddedShape,
-                                alreadyHasRequestedShape)))
+  // Initialize the padded shape and get whether it requires padding.
+  PaddedShape shape;
+  FailureOr<bool> alreadyHasRequestedShape =
+      shape.initialize(opToPad, opOperand, options);
+  if (failed(alreadyHasRequestedShape)) {
     return rewriter.notifyMatchFailure(opToPad,
                                        "--failed to compute padded shape");
+  }
 
-  // Return the unpadded operand if padding to a static shape is not needed and
+  // Return the un-padded operand if padding to a static shape is not needed and
   // if the nofold flag is not set.
   bool nofold = opOperand->getOperandNumber() < options.nofoldFlags.size()
                     ? bool(options.nofoldFlags[opOperand->getOperandNumber()])
                     : false;
-  if (!nofold && alreadyHasRequestedShape)
+  if (!nofold && *alreadyHasRequestedShape)
     return opOperand->get();
 
   // Fail if `paddingValues` specifies no padding value.
@@ -140,13 +226,18 @@ static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
         opToPad.getLoc(), cast<TypedAttr>(paddingAttr));
   }
 
+  // Compute the padded shape.
+  if (!*alreadyHasRequestedShape)
+    shape.computePadding(rewriter, opOperand->get());
+
   // Pad the operand to the bounding box defined by `paddedShape`.
-  auto paddedTensorType = RankedTensorType::get(
-      paddedShape, getElementTypeOrSelf(opOperand->get()));
+  RankedTensorType paddedTensorType =
+      shape.getType(getElementTypeOrSelf(opOperand->get()));
   LLVM_DEBUG(DBGS() << "--SUCCESS, makeComposedPadHighOp with type: "
                     << paddedTensorType);
   return makeComposedPadHighOp(rewriter, opToPad->getLoc(), paddedTensorType,
-                               opOperand->get(), paddingValue, nofold);
+                               opOperand->get(), paddingValue, nofold,
+                               shape.dynDims);
 }
 
 LogicalResult
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 2527d90cfa2e6..209309ddb413a 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -244,11 +244,13 @@ bool isReductionIterator(utils::IteratorType iteratorType) {
 }
 
 Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
-                            Value source, Value pad, bool nofold) {
+                            Value source, Value pad, bool nofold,
+                            ValueRange typeDynDims) {
   // Exit if `source` is not defined by an ExtractSliceOp.
   auto sliceOp = source.getDefiningOp<tensor::ExtractSliceOp>();
   if (!sliceOp)
-    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);
+    return tensor::createPadHighOp(type, source, pad, nofold, loc, b,
+                                   typeDynDims);
 
   // Search the `source` use-def chain for padded LinalgOps.
   Value current = sliceOp.getSource();
@@ -264,24 +266,28 @@ Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
   // Exit if the search fails to match a tensor::PadOp at the end of the matched
   // LinalgOp sequence.
   if (!padOp)
-    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);
+    return tensor::createPadHighOp(type, source, pad, nofold, loc, b,
+                                   typeDynDims);
 
   // Exit if the padded result type does not match.
   if (sliceOp.getSource().getType() != type)
-    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);
+    return tensor::createPadHighOp(type, source, pad, nofold, loc, b,
+                                   typeDynDims);
 
   // Exit if the LinalgOps are not high padded.
   if (llvm::any_of(padOp.getMixedLowPad(), [](OpFoldResult ofr) {
        return getConstantIntValue(ofr) != static_cast<int64_t>(0);
      }))
-    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);
+    return tensor::createPadHighOp(type, source, pad, nofold, loc, b,
+                                   typeDynDims);
 
   // Exit if `padOpSliceOp`, which defines the slice used by
   // `padOp`, is rank-reducing.
   auto padOpSliceOp = padOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
   if (!padOpSliceOp ||
       sliceOp.getMixedSizes().size() != padOpSliceOp.getMixedSizes().size())
-    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);
+    return tensor::createPadHighOp(type, source, pad, nofold, loc, b,
+                                   typeDynDims);
 
   // Exit if the sizes of the dynamic sizes of `sliceOp` do not match the size
   // of the slice padded by `padOp`.
@@ -290,14 +296,16 @@ Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
       [](std::tuple<OpFoldResult, OpFoldResult> it) {
         return !isEqualConstantIntOrValue(std::get<0>(it), std::get<1>(it));
       }))
-    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);
+    return tensor::createPadHighOp(type, source, pad, nofold, loc, b,
+                                   typeDynDims);
 
   // Exit if the padding values do not match.
   Attribute padOpPadAttr, padAttr;
   Value padOpPad = padOp.getConstantPaddingValue();
   if (!padOpPad || !matchPattern(padOpPad, m_Constant(&padOpPadAttr)) ||
       !matchPattern(pad, m_Constant(&padAttr)) || padOpPadAttr != padAttr)
-    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);
+    return tensor::createPadHighOp(type, source, pad, nofold, loc, b,
+                                   typeDynDims);
 
   // Return the padded result if the padding values and sizes match.
   return sliceOp.getSource();
 }
 
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
index ab2711545405e..ff3ec1625511b 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
@@ -300,7 +300,7 @@ func.func @negative_no_ub_estimate(%arg0: tensor,
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    // expected-error @below {{ailed to pad op}}
+    // expected-error @below {{failed to pad op}}
     %padded, %pad, %copy_back = transform.structured.pad %0 {
       padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32],
       // Note - attempting to pad non-static dim
@@ -416,6 +416,6 @@ module attributes {transform.with_named_sequence} {
       padding_dimensions=[0, 1, 2],
       nofold_flags=[1, 1, 1]
     } : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
-    transform.yield 
+    transform.yield
   }
 }

From ebc79ee98de33a090f22cd6ad4018baeb010d78f Mon Sep 17 00:00:00 2001
From: Fabian Mora
Date: Wed, 18 Jun 2025 14:49:02 +0000
Subject: [PATCH 2/2] add use_prescribed_tensor_shapes option

---
 .../Linalg/TransformOps/LinalgTransformOps.td | 10 ++++--
 .../TransformOps/LinalgTransformOps.cpp       | 36 ++++++++++++++++---
 .../test/Dialect/Linalg/transform-op-pad.mlir | 35 ++++++++++++++++++
 3 files changed, 74 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 15ea5e7bf7159..6f6df350f1ba6 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1134,7 +1134,8 @@ def PadOp : Op
                        "{}">:$transpose_paddings,
-                   DefaultValuedAttr:$copy_back_op);
+                   DefaultValuedAttr:$copy_back_op,
+                   DefaultValuedAttr:$use_prescribed_tensor_shapes);
   let results = (outs TransformHandleTypeInterface:$padded,
                       TransformHandleTypeInterface:$pad,
                       TransformHandleTypeInterface:$copy);
 
@@ -1142,6 +1143,7 @@ def PadOp : Op
     ($pad_to_multiple_of, $static_pad_to_multiple_of)^)?
+    (`use_prescribed_tensor_shapes` $use_prescribed_tensor_shapes^)?
     attr-dict
     `:` functional-type(operands, results)
  }];
 
@@ -1159,13 +1161,15 @@ def PadOp : Op
                    CArg<"ArrayRef<int64_t>", "{}">:$staticPadToMultipleOf,
                    CArg<"ArrayRef<int64_t>", "{}">:$nofoldFlags,
                    CArg<"ArrayRef<Attribute>", "{}">:$transposePaddings,
-                   CArg<"StringRef", "::mlir::bufferization::MaterializeInDestinationOp::getOperationName()">:$copyBackOp)>,
+                   CArg<"StringRef", "::mlir::bufferization::MaterializeInDestinationOp::getOperationName()">:$copyBackOp,
+                   CArg<"bool", "false">:$usePrescribedTensorShapes)>,
    OpBuilder<(ins "Value":$target,
                   "ArrayRef<int64_t>":$paddingDimensions,
                   "ArrayRef<OpFoldResult>":$mixedPadToMultipleOf,
                   CArg<"ArrayRef<int64_t>", "{}">:$nofoldFlags,
                   CArg<"ArrayRef<Attribute>", "{}">:$transposePaddings,
-                  CArg<"StringRef", "::mlir::bufferization::MaterializeInDestinationOp::getOperationName()">:$copyBackOp)>
+                  CArg<"StringRef", "::mlir::bufferization::MaterializeInDestinationOp::getOperationName()">:$copyBackOp,
+                  CArg<"bool", "false">:$usePrescribedTensorShapes)>
  ];
 
  let extraClassDeclaration = [{
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index b2c28f5eed33c..d78c8847f8843 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1907,7 +1907,8 @@ void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
                              ArrayRef<int64_t> padToMultipleOf,
                              ArrayRef<int64_t> nofoldFlags,
                              ArrayRef<Attribute> transposePaddings,
-                             StringRef copyBackOp) {
+                             StringRef copyBackOp,
+                             bool usePrescribedTensorShapes) {
   auto resultType = transform::AnyOpType::get(b.getContext());
   return build(/*builder=*/b,
               /*result=*/result,
@@ -1922,7 +1923,9 @@ void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
                    : b.getDenseI64ArrayAttr(padToMultipleOf)),
               /*nofoldFlags=*/b.getI64ArrayAttr(nofoldFlags),
               /*transposePaddings=*/b.getArrayAttr(transposePaddings),
-              /*copyBackOp=*/b.getStringAttr(copyBackOp));
+              /*copyBackOp=*/b.getStringAttr(copyBackOp),
+              /*usePrescribedTensorShapes=*/
+              usePrescribedTensorShapes ? b.getUnitAttr() : nullptr);
 }
 
 void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
                              ArrayRef<int64_t> paddingDimensions,
                              ArrayRef<OpFoldResult> mixedPadToMultipleOf,
                              ArrayRef<int64_t> nofoldFlags,
                              ArrayRef<Attribute> transposePaddings,
-                             StringRef copyBackOp) {
+                             StringRef copyBackOp,
+                             bool usePrescribedTensorShapes) {
   auto resultType = transform::AnyOpType::get(b.getContext());
   SmallVector<int64_t> staticPadToMultipleOf;
   SmallVector<Value> dynamicPadToMultipleOf;
@@ -1946,7 +1950,8 @@ void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
                /*padToMultipleOf=*/staticPadToMultipleOf,
                /*nofoldFlags=*/b.getI64ArrayAttr(nofoldFlags),
                /*transposePaddings=*/b.getArrayAttr(transposePaddings),
-               /*copyBackOp=*/b.getStringAttr(copyBackOp));
+               /*copyBackOp=*/copyBackOp,
+               /*usePrescribedTensorShapes=*/usePrescribedTensorShapes);
 }
 
 void PadOp::getEffects(
@@ -2051,11 +2056,34 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
   } else {
     llvm_unreachable("unsupported copy_back op");
   }
+  // Populate `sizeToPadTo` with the dynamic tensor sizes for each operand.
+  bool irChanged = false;
+  if (getUsePrescribedTensorShapes() &&
+      linalgTarget.hasPureTensorSemantics()) {
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(linalgTarget);
+    for (OpOperand &operand : linalgTarget->getOpOperands()) {
+      for (auto [i, dim] : llvm::enumerate(linalgTarget.getShape(&operand))) {
+        if (!ShapedType::isDynamic(dim))
+          continue;
+        options.setSizeToPadTo(operand.getOperandNumber(), i,
+                               tensor::getMixedSize(rewriter,
+                                                    operand.get().getLoc(),
+                                                    operand.get(), i));
+        irChanged = true;
+      }
+    }
+  }
   SmallVector<Value> replacements;
   SmallVector<tensor::PadOp> newPadOps;
   if (failed(rewriteAsPaddedOp(rewriter, linalgTarget, options, paddedOp,
                                replacements, newPadOps))) {
+    if (irChanged) {
+      auto diag = emitDefiniteFailure() << "failed to pad op";
+      diag.attachNote(target->getLoc()) << "target op";
+      return diag;
+    }
     auto diag = emitSilenceableError() << "failed to pad op";
     diag.attachNote(target->getLoc()) << "target op";
     return diag;
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
index ff3ec1625511b..bc684b53c9b61 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
@@ -313,6 +313,41 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// Test dynamic padding using `use_prescribed_tensor_shapes`
+
+// CHECK: #[[MAP:.*]] = affine_map<()[s0] -> (-s0 + (s0 ceildiv 7) * 7)>
+// CHECK: @use_prescribed_tensor_shapes
+// CHECK: (%[[ARG0:.*]]: tensor<?x12xf32>, %[[ARG1:.*]]: tensor<12x?xf32>
+func.func @use_prescribed_tensor_shapes(%arg0: tensor<?x12xf32>,
+                                        %arg1: tensor<12x?xf32>,
+                                        %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  // CHECK: %[[C1_0:.*]] = arith.constant 1 : index
+  // CHECK: %[[DIM_0:.*]] = tensor.dim %[[ARG1]], %[[C1_0]] : tensor<12x?xf32>
+  // CHECK: %[[PADDING:.*]] = affine.apply #[[MAP]]()[%[[DIM_0]]]
+  // CHECK: %[[PADDED:.*]] = tensor.pad %[[ARG1]] low[0, 0] high[0, %[[PADDING]]] {
+  // CHECK: linalg.matmul ins(%[[ARG0]], %[[PADDED]] : tensor<?x12xf32>, tensor<12x?xf32>)
+  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x12xf32>, tensor<12x?xf32>) outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  func.return %0 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %padded, %pad, %copy_back = transform.structured.pad %0
+      pad_to_multiple_of [7] use_prescribed_tensor_shapes {
+        padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32],
+        padding_dimensions=[1]
+      } : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.canonicalization
+    } {apply_cse} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 // Check that the padding can be applied even when the output argument of the
 // linalg op is not produced by an empty op or an extract_slice op.
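
Editor's note, not part of the patches above: the sketch below illustrates how the new `sizeToPadTo` hook from patch 1 might be driven directly from C++, mirroring what `transform.structured.pad` does when `use_prescribed_tensor_shapes` is set. It assumes `using namespace mlir;`, the usual Linalg/Tensor includes, and a caller that provides a `RewriterBase` and a three-operand f32 `linalg.matmul` with tensor semantics; the helper's name and the choice of operand/dimension are illustrative only.

  // Minimal sketch: pad dim 1 of operand 1 up to its runtime size, rounded up
  // to a multiple of 7, instead of a computed constant upper bound.
  static LogicalResult padOperandToItsRuntimeSize(RewriterBase &rewriter,
                                                  linalg::LinalgOp op) {
    Attribute zero = rewriter.getZeroAttr(rewriter.getF32Type());
    linalg::LinalgPaddingOptions options;
    options.setPaddingDimensions({1});
    options.setPadToMultipleOf({7});
    options.setPaddingValues({zero, zero, zero});
    // New in patch 1: prescribe the size to pad (operand 1, dim 1) to; here
    // the operand's runtime size, so no static bounding box is required.
    OpOperand &operand = op->getOpOperand(1);
    options.setSizeToPadTo(/*operandIndex=*/1, /*dimIndex=*/1,
                           tensor::getMixedSize(rewriter, op.getLoc(),
                                                operand.get(), /*dim=*/1));

    linalg::LinalgOp paddedOp;
    SmallVector<Value> replacements;
    SmallVector<tensor::PadOp> newPadOps;
    return linalg::rewriteAsPaddedOp(rewriter, op, options, paddedOp,
                                     replacements, newPadOps);
  }

Using the transform dialect instead, the same behavior is what the new `use_prescribed_tensor_shapes` unit attribute requests, as exercised by the test added in patch 2.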