@@ -1134,14 +1134,16 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
DefaultValuedAttr<
TypedArrayAttrBase<I64ArrayAttr, "array of arrays of i64">,
"{}">:$transpose_paddings,
DefaultValuedAttr<StrAttr, "::mlir::bufferization::MaterializeInDestinationOp::getOperationName()">:$copy_back_op);
DefaultValuedAttr<StrAttr, "::mlir::bufferization::MaterializeInDestinationOp::getOperationName()">:$copy_back_op,
DefaultValuedAttr<UnitAttr, "false">:$use_prescribed_tensor_shapes);
let results = (outs TransformHandleTypeInterface:$padded,
TransformHandleTypeInterface:$pad,
TransformHandleTypeInterface:$copy);

let assemblyFormat = [{
$target
(`pad_to_multiple_of` custom<DynamicIndexList>($pad_to_multiple_of, $static_pad_to_multiple_of)^)?
(`use_prescribed_tensor_shapes` $use_prescribed_tensor_shapes^)?
attr-dict
`:` functional-type(operands, results)
}];
@@ -1159,13 +1161,15 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
CArg<"ArrayRef<int64_t>", "{}">:$staticPadToMultipleOf,
CArg<"ArrayRef<int64_t>", "{}">:$nofoldFlags,
CArg<"ArrayRef<Attribute>", "{}">:$transposePaddings,
CArg<"StringRef", "::mlir::bufferization::MaterializeInDestinationOp::getOperationName()">:$copyBackOp)>,
CArg<"StringRef", "::mlir::bufferization::MaterializeInDestinationOp::getOperationName()">:$copyBackOp,
CArg<"bool", "false">:$usePrescribedTensorShapes)>,
OpBuilder<(ins "Value":$target,
"ArrayRef<int64_t>":$paddingDimensions,
"ArrayRef<OpFoldResult>":$mixedPadToMultipleOf,
CArg<"ArrayRef<int64_t>", "{}">:$nofoldFlags,
CArg<"ArrayRef<Attribute>", "{}">:$transposePaddings,
CArg<"StringRef", "::mlir::bufferization::MaterializeInDestinationOp::getOperationName()">:$copyBackOp)>
CArg<"StringRef", "::mlir::bufferization::MaterializeInDestinationOp::getOperationName()">:$copyBackOp,
CArg<"bool", "false">:$usePrescribedTensorShapes)>
];

let extraClassDeclaration = [{
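To make the new flag concrete, here is a minimal, hedged C++ sketch of invoking the extended builder declared above. The `builder`, `loc`, and `target` handle are assumed to exist in the surrounding pass or test, and the Bufferization dialect header is assumed to be available for the default copy-back op name; the trailing `true` corresponds to the new `use_prescribed_tensor_shapes` unit attribute.

```cpp
// Hypothetical call site: build a transform.structured.pad that pads
// dimensions 0, 1, and 2 and asks the rewrite to use the prescribed
// (runtime) tensor shapes instead of constant upper bounds.
auto padOp = builder.create<transform::PadOp>(
    loc, /*target=*/target,
    /*paddingDimensions=*/ArrayRef<int64_t>{0, 1, 2},
    /*padToMultipleOf=*/ArrayRef<int64_t>{},
    /*nofoldFlags=*/ArrayRef<int64_t>{},
    /*transposePaddings=*/ArrayRef<Attribute>{},
    /*copyBackOp=*/
    bufferization::MaterializeInDestinationOp::getOperationName(),
    /*usePrescribedTensorShapes=*/true);
```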
17 changes: 17 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -295,6 +295,23 @@ struct LinalgPaddingOptions {
padToMultipleOf.emplace(m.begin(), m.end());
return *this;
}
/// A mapping from an (operand index, shape dimension) pair to the size that
/// dimension should be padded to. Each size is expected to be greater than
/// or equal to the corresponding shape dimension. If no value is provided,
/// the constant upper bound is used instead.
DenseMap<std::pair<unsigned, unsigned>, OpFoldResult> sizeToPadTo;
LinalgPaddingOptions &setSizeToPadTo(unsigned operandIndex, unsigned dimIndex,
OpFoldResult size) {
assert(size && "expected non-null size");
sizeToPadTo[{operandIndex, dimIndex}] = size;
return *this;
}
/// Given the operand index and shape dim it returns the size to pad to.
OpFoldResult getSizeToPadTo(unsigned operandIndex, unsigned dimIndex) const {
return sizeToPadTo.lookup_or(
std::pair<unsigned, unsigned>(operandIndex, dimIndex), nullptr);
}

/// A flag for every operand to mark the PadOp as nofold which enables
/// packing for statically shaped operands.
SmallVector<bool> nofoldFlags;
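The following is a minimal sketch of the new `sizeToPadTo` hook in isolation, assuming a `rewriter`, a `loc`, and an `OpOperand` named `operand` of the op being padded (all illustrative names). It records the runtime extent of a dynamic dimension so the padding rewrite pads to that prescribed size rather than to an inferred constant upper bound.

```cpp
// Record that dim 1 of operand 0 should be padded to its runtime size.
linalg::LinalgPaddingOptions options;
OpFoldResult prescribedSize =
    tensor::getMixedSize(rewriter, loc, operand.get(), /*dim=*/1);
options.setSizeToPadTo(/*operandIndex=*/0, /*dimIndex=*/1, prescribedSize);

// The getter returns a null OpFoldResult when no size was prescribed, in
// which case the padding logic falls back to the constant upper bound.
OpFoldResult size = options.getSizeToPadTo(/*operandIndex=*/0, /*dimIndex=*/1);
assert(size && "expected the size that was just recorded");
```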
17 changes: 10 additions & 7 deletions mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -71,12 +71,14 @@ bool isParallelIterator(utils::IteratorType iteratorType);
/// Check if iterator type has "reduction" semantics.
bool isReductionIterator(utils::IteratorType iteratorType);

/// Create a tensor::PadOp that pads `source` to the size of the statically
/// sized `type` whose static sizes are assumed to be greater than the dynamic
/// `source` size. The padding introduces trailing `pad` values until the
/// target size is met. If `source` is defined by one or more LinalgOps that
/// have been padded with the same value and sizes, return their padded result
/// instead of creating a tensor::PadOp.
/// Create a tensor::PadOp that pads `source` to the shape of `type` whose sizes
/// are assumed to be greater than the dynamic `source` size. If `typeDynDims`
/// is specified, then it must contain the sizes of all the dynamic dimensions
/// in order of appearance in `type`, otherwise the function will pad those
/// values to `0`. The padding introduces trailing `pad` values until the target
/// size is met. If `source` is defined by one or more LinalgOps that have been
/// padded with the same value and sizes, return their padded result instead of
/// creating a tensor::PadOp.
///
/// Example:
/// ```
@@ -91,7 +93,8 @@ bool isReductionIterator(utils::IteratorType iteratorType);
/// %4 = tensor.pad %3 low[0, 0] high[...] { tensor.yield %other_cst }
/// ```
Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
Value source, Value pad, bool nofold);
Value source, Value padding, bool nofold,
ValueRange typeDynDims = std::nullopt);

/// Returns GenericOp that copies an n-D memref. Unlike the current
/// implementation of memref::CopyOp, this op can further tile, lower to loops
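Below is a short, hedged sketch of calling `makeComposedPadHighOp` with the new `typeDynDims` argument. The names `b`, `loc`, `source`, `padValue`, and `dynSize` (the SSA extent of the single dynamic dimension) are assumptions about the caller's scope, not part of the API.

```cpp
// Pad `source` up to a ?x16 target type; the dynamic extent is supplied
// explicitly rather than being defaulted as described in the comment above.
auto targetType =
    RankedTensorType::get({ShapedType::kDynamic, 16}, b.getF32Type());
SmallVector<Value> dynDims = {dynSize}; // one value per dynamic dim, in order
Value padded = linalg::makeComposedPadHighOp(b, loc, targetType, source,
                                             padValue, /*nofold=*/false,
                                             dynDims);
```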
36 changes: 32 additions & 4 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1907,7 +1907,8 @@ void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
ArrayRef<int64_t> padToMultipleOf,
ArrayRef<int64_t> nofoldFlags,
ArrayRef<Attribute> transposePaddings,
StringRef copyBackOp) {
StringRef copyBackOp,
bool usePrescribedTensorShapes) {
auto resultType = transform::AnyOpType::get(b.getContext());
return build(/*builder=*/b,
/*result=*/result,
@@ -1922,15 +1923,18 @@ void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
: b.getDenseI64ArrayAttr(padToMultipleOf)),
/*nofoldFlags=*/b.getI64ArrayAttr(nofoldFlags),
/*transposePaddings=*/b.getArrayAttr(transposePaddings),
/*copyBackOp=*/b.getStringAttr(copyBackOp));
/*copyBackOp=*/b.getStringAttr(copyBackOp),
/*usePrescribedTensorShapes=*/
usePrescribedTensorShapes ? b.getUnitAttr() : nullptr);
}

void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
ArrayRef<int64_t> paddingDimensions,
ArrayRef<OpFoldResult> mixedPadToMultipleOf,
ArrayRef<int64_t> nofoldFlags,
ArrayRef<Attribute> transposePaddings,
StringRef copyBackOp) {
StringRef copyBackOp,
bool usePrescribedTensorShapes) {
auto resultType = transform::AnyOpType::get(b.getContext());
SmallVector<int64_t> staticPadToMultipleOf;
SmallVector<Value> dynamicPadToMultipleOf;
@@ -1946,7 +1950,8 @@ void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
/*padToMultipleOf=*/staticPadToMultipleOf,
/*nofoldFlags=*/b.getI64ArrayAttr(nofoldFlags),
/*transposePaddings=*/b.getArrayAttr(transposePaddings),
/*copyBackOp=*/b.getStringAttr(copyBackOp));
/*copyBackOp=*/copyBackOp,
/*usePrescribedTensorShapes=*/usePrescribedTensorShapes);
}

void PadOp::getEffects(
@@ -2051,11 +2056,34 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
} else {
llvm_unreachable("unsupported copy_back op");
}
// Populate `sizeToPadTo` with the dynamic tensor sizes for each operand.
bool irChanged = false;
if (getUsePrescribedTensorShapes() &&
linalgTarget.hasPureTensorSemantics()) {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(linalgTarget);
for (OpOperand &operand : linalgTarget->getOpOperands()) {
for (auto [i, dim] : llvm::enumerate(linalgTarget.getShape(&operand))) {
if (!ShapedType::isDynamic(dim))
continue;
options.setSizeToPadTo(operand.getOperandNumber(), i,
tensor::getMixedSize(rewriter,
operand.get().getLoc(),
operand.get(), i));
irChanged = true;
}
}
}

SmallVector<Value> replacements;
SmallVector<tensor::PadOp> newPadOps;
if (failed(rewriteAsPaddedOp(rewriter, linalgTarget, options, paddedOp,
replacements, newPadOps))) {
if (irChanged) {
auto diag = emitDefiniteFailure() << "failed to pad op";
diag.attachNote(target->getLoc()) << "target op";
return diag;
}
auto diag = emitSilenceableError() << "failed to pad op";
diag.attachNote(target->getLoc()) << "target op";
return diag;
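For readers driving the C++ padding API directly rather than through the transform op, the sketch below condenses the flow added above: record every dynamic operand dimension in the options, then call `rewriteAsPaddedOp`. This is a hedged reconstruction; `rewriter` and `linalgOp` are assumed to be in scope, and the `rewriteAsPaddedOp` signature is inferred from the call site in the diff.

```cpp
// Prescribe the runtime size of every dynamic operand dimension, mirroring
// what transform::PadOp::apply does when use_prescribed_tensor_shapes is set.
linalg::LinalgPaddingOptions options;
for (OpOperand &operand : linalgOp->getOpOperands()) {
  for (auto [dimIdx, dimSize] : llvm::enumerate(linalgOp.getShape(&operand))) {
    if (!ShapedType::isDynamic(dimSize))
      continue;
    options.setSizeToPadTo(operand.getOperandNumber(), dimIdx,
                           tensor::getMixedSize(rewriter,
                                                operand.get().getLoc(),
                                                operand.get(), dimIdx));
  }
}

// Rewrite the op with padded operands; the out-parameters receive the padded
// op, the replacement values, and the newly created tensor.pad ops.
linalg::LinalgOp paddedOp;
SmallVector<Value> replacements;
SmallVector<tensor::PadOp> padOps;
if (failed(linalg::rewriteAsPaddedOp(rewriter, linalgOp, options, paddedOp,
                                     replacements, padOps)))
  return failure();
```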