
Commit 26fd8bd

[mlir][linalg] Add sizeToPadTo option to linalg::LinalgPaddingOptions
1 parent acde20b commit 26fd8bd

5 files changed (+172, -53 lines)

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 17 additions & 0 deletions
@@ -295,6 +295,23 @@ struct LinalgPaddingOptions {
     padToMultipleOf.emplace(m.begin(), m.end());
     return *this;
   }
+  /// A mapping between an operand and shape dim, and a size for a padding
+  /// dimension. Each size is expected to be greater or equal than the
+  /// corresponding shape dim. If no value is provided then the constant upper
+  /// bound will be used.
+  DenseMap<std::pair<unsigned, unsigned>, OpFoldResult> sizeToPadTo;
+  LinalgPaddingOptions &setSizeToPadTo(unsigned operandIndex, unsigned dimIndex,
+                                       OpFoldResult size) {
+    assert(size && "expected non-null size");
+    sizeToPadTo[{operandIndex, dimIndex}] = size;
+    return *this;
+  }
+  /// Given the operand index and shape dim it returns the size to pad to.
+  OpFoldResult getSizeToPadTo(unsigned operandIndex, unsigned dimIndex) const {
+    return sizeToPadTo.lookup_or(
+        std::pair<unsigned, unsigned>(operandIndex, dimIndex), nullptr);
+  }
+
   /// A flag for every operand to mark the PadOp as nofold which enables
   /// packing for statically shaped operands.
   SmallVector<bool> nofoldFlags;
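
As a quick illustration of the new option (not part of this commit), a caller could request a specific padded size for one operand dimension instead of relying on the computed constant upper bound. The builder `b`, the operand/dim indices, and the size `64` below are assumptions made for the sketch:

  // Hypothetical configuration: pad dim 1 of operand 0 up to size 64 rather
  // than to the inferred static bounding box.
  linalg::LinalgPaddingOptions options;
  options.paddingDimensions = {0, 1};
  options.setSizeToPadTo(/*operandIndex=*/0, /*dimIndex=*/1,
                         /*size=*/b.getIndexAttr(64));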

mlir/include/mlir/Dialect/Linalg/Utils/Utils.h

Lines changed: 10 additions & 7 deletions
@@ -71,12 +71,14 @@ bool isParallelIterator(utils::IteratorType iteratorType);
 /// Check if iterator type has "reduction" semantics.
 bool isReductionIterator(utils::IteratorType iteratorType);
 
-/// Create a tensor::PadOp that pads `source` to the size of the statically
-/// sized `type` whose static sizes are assumed to be greater than the dynamic
-/// `source` size. The padding introduces trailing `pad` values until the
-/// target size is met. If `source` is defined by one or more LinalgOps that
-/// have been padded with the same value and sizes, return their padded result
-/// instead of creating a tensor::PadOp.
+/// Create a tensor::PadOp that pads `source` to the shape of `type` whose sizes
+/// are assumed to be greater than the dynamic `source` size. If `typeDynDims`
+/// is specified, then it must contain the sizes of all the dynamic dimensions
+/// in order of appearance in `type`, otherwise the function will pad those
+/// values to `0`. The padding introduces trailing `pad` values until the target
+/// size is met. If `source` is defined by one or more LinalgOps that have been
+/// padded with the same value and sizes, return their padded result instead of
+/// creating a tensor::PadOp.
 ///
 /// Example:
 /// ```
@@ -91,7 +93,8 @@ bool isReductionIterator(utils::IteratorType iteratorType);
 /// %4 = tensor.pad %3 low[0, 0] high[...] { tensor.yield %other_cst }
 /// ```
 Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
-                            Value source, Value pad, bool nofold);
+                            Value source, Value padding, bool nofold,
+                            ValueRange typeDynDims = std::nullopt);
 
 /// Returns GenericOp that copies an n-D memref. Unlike the current
 /// implementation of memref::CopyOp, this op can further tile, lower to loops
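
A hedged call-site sketch for the updated signature; the names `builder`, `loc`, `paddedType`, `source`, `padValue`, and `dynSizes` are placeholders and not part of this commit:

  // Pad `source` up to `paddedType`; `dynSizes` supplies one value per dynamic
  // dimension of `paddedType`, in order of appearance.
  Value padded = makeComposedPadHighOp(builder, loc, paddedType, source,
                                       padValue, /*nofold=*/false,
                                       /*typeDynDims=*/dynSizes);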

mlir/lib/Dialect/Linalg/Transforms/Padding.cpp

Lines changed: 127 additions & 36 deletions
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -22,53 +23,93 @@ using namespace mlir::linalg;
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
 #define DBGSNL() (llvm::dbgs() << "\n")
 
-/// Compute the padded shape of the given operand. The operand is padded to a
-/// static bounding box according to the specified padding options.
-static LogicalResult computePaddedShape(linalg::LinalgOp opToPad,
+namespace {
+/// Helper class for storing padding information.
+struct PaddingInfo {
+  PaddingInfo(int64_t padToMultipleOf = 1, OpFoldResult size = {})
+      : padToMultipleOf(padToMultipleOf), size(size) {}
+  /// Pad the tensor to a multiple of.
+  int64_t padToMultipleOf = 1;
+  /// The size used for padding.
+  OpFoldResult size = {};
+};
+
+/// Helper class for storing and computing the padded shape.
+struct PaddedShape {
+  /// Initializes the shape information and on success it returns whether the
+  /// shape of the operand will change. Returns failure if the operand cannot be
+  /// padded.
+  FailureOr<bool> initialize(linalg::LinalgOp opToPad, OpOperand *opOperand,
+                             const LinalgPaddingOptions &options);
+
+  /// Computs the padded shape.
+  void computePadding(OpBuilder &builder, Value operand);
+
+  /// Returns the new tensor type.
+  RankedTensorType getType(Type elemTy) {
+    return RankedTensorType::get(shape, elemTy);
+  }
+
+  SmallVector<Value> dynDims;
+
+private:
+  SmallVector<int64_t> shape;
+  DenseMap<int64_t, PaddingInfo> dimToInfo;
+};
+} // namespace
+
+FailureOr<bool> PaddedShape::initialize(linalg::LinalgOp opToPad,
                                         OpOperand *opOperand,
-                                        const LinalgPaddingOptions &options,
-                                        SmallVector<int64_t> &paddedShape,
-                                        bool &alreadyHasRequestedShape) {
+                                        const LinalgPaddingOptions &options) {
   AffineMap indexingMap = opToPad.getMatchingIndexingMap(opOperand);
-  ArrayRef<int64_t> shape = opToPad.getShape(opOperand);
+
+  // Initialize the padded shape.
+  llvm::append_range(shape, opToPad.getShape(opOperand));
 
   // Collect the shape dimensions that are a function of "paddingDimensions",
   // along with the multiple that they should be padded to ("1" if none).
-  alreadyHasRequestedShape = true;
-  DenseMap<int64_t, int64_t> shapeDimToMultiple;
+  bool alreadyHasRequestedShape = true;
   for (const auto &dimEn : enumerate(options.paddingDimensions)) {
     for (const auto &en : enumerate(indexingMap.getResults())) {
       if (en.value().isFunctionOfDim(dimEn.value())) {
+        PaddingInfo paddingInfo;
         int64_t dimSize = shape[en.index()];
         if (options.padToMultipleOf.has_value()) {
-          shapeDimToMultiple[en.index()] =
+          paddingInfo.padToMultipleOf =
               (*options.padToMultipleOf)[dimEn.index()];
         } else {
-          shapeDimToMultiple[en.index()] = 1;
+          paddingInfo.padToMultipleOf = 1;
         }
-        if (ShapedType::isDynamic(dimSize)) {
-          alreadyHasRequestedShape = false;
-        } else if (dimSize % shapeDimToMultiple[en.index()] != 0) {
+
+        // Check if the user provided a size in the options.
+        paddingInfo.size =
+            options.getSizeToPadTo(opOperand->getOperandNumber(), en.index());
+
+        // Set the padding info.
+        dimToInfo[en.index()] = paddingInfo;
+        if (ShapedType::isDynamic(dimSize) ||
+            dimSize % paddingInfo.padToMultipleOf != 0 ||
+            !paddingInfo.size.isNull()) {
           alreadyHasRequestedShape = false;
         }
       }
     }
   }
 
-  // Helper function to round a number up to a given multiple.
-  auto ceil = [](int64_t val, int64_t multiple) {
-    return ((val + multiple - 1) / multiple) * multiple;
-  };
-
   // Upper bound the sizes to obtain a static bounding box.
-  paddedShape.assign(shape.begin(), shape.end());
   for (int64_t i = 0, e = shape.size(); i < e; ++i) {
-    LLVM_DEBUG(DBGS() << "--compute padded size for dim " << i << "\n");
+    LLVM_DEBUG(DBGS() << "--computing un-padded size for dim " << i << "\n");
     // Skip dimensions that do not require padding.
-    if (!shapeDimToMultiple.contains(i)) {
+    if (!dimToInfo.contains(i)) {
       LLVM_DEBUG(DBGS() << "----dim does not require padding, SKIP\n");
       continue;
     }
+    PaddingInfo &info = dimToInfo[i];
+    if (info.size) {
+      LLVM_DEBUG(DBGS() << "----the user provided the size: " << info.size
+                        << "\n");
+      continue;
+    }
     // Otherwise, try to compute a constant upper bound for the size value.
     FailureOr<int64_t> upperBound =
         ValueBoundsConstraintSet::computeConstantBound(
@@ -77,14 +118,58 @@ static LogicalResult computePaddedShape(linalg::LinalgOp opToPad,
             /*dim=*/i},
            /*stopCondition=*/nullptr, /*closedUB=*/true);
     if (failed(upperBound)) {
-      LLVM_DEBUG(DBGS() << "----could not compute a bounding box for padding");
+      LLVM_DEBUG(
+          DBGS() << "----could not compute a bounding box for padding\n");
       return failure();
     }
-    paddedShape[i] = ceil(*upperBound, shapeDimToMultiple[i]);
-    LLVM_DEBUG(DBGS() << "----new dim size: " << paddedShape[i] << "\n");
+    info.size =
+        IntegerAttr::get(IndexType::get(opToPad.getContext()), *upperBound);
+    LLVM_DEBUG(DBGS() << "----new un-padded size: " << info.size << "\n");
   }
+  return alreadyHasRequestedShape;
+}
 
-  return success();
+void PaddedShape::computePadding(OpBuilder &builder, Value operand) {
+  Location loc = operand.getLoc();
+  AffineExpr sizeSym = builder.getAffineSymbolExpr(0);
+
+  // Compute the padding for each dimension.
+  for (auto &&[i, dim] : llvm::enumerate(shape)) {
+    LLVM_DEBUG(DBGS() << "--computing padded size for dim " << i << "\n");
+
+    // Get the padding info or default info for the shape dimension.
+    PaddingInfo paddingInfo = dimToInfo.lookup(i);
+
+    // Skip dimensions that do not require padding.
+    if (paddingInfo.size.isNull()) {
+      LLVM_DEBUG(DBGS() << "----dim does not require padding, SKIP\n");
+
+      // We still need to push the size as `makeComposedPadHighOp` expects a
+      // range with all the dynamic sizes, whether they're being padded or not.
+      if (ShapedType::isDynamic(dim)) {
+        dynDims.push_back(
+            cast<Value>(tensor::getMixedSize(builder, loc, operand, i)));
+      }
+      continue;
+    }
+
+    // Compute the padded size to be a multiple of `padToMultipleOf`.
+    AffineExpr szExpr = (sizeSym).ceilDiv(paddingInfo.padToMultipleOf) *
+                        paddingInfo.padToMultipleOf;
+    OpFoldResult paddedSize = affine::makeComposedFoldedAffineApply(
+        builder, loc, szExpr, paddingInfo.size);
+    assert(paddedSize && "invalid arguments to affine apply");
+
+    if (auto cstSzAttr = dyn_cast<Attribute>(paddedSize)) {
+      // Update the shape as the size is static.
+      dim = cast<IntegerAttr>(cstSzAttr).getValue().getZExtValue();
+    } else {
+      // Add a dynamic dimension.
+      dim = ShapedType::kDynamic;
+      dynDims.push_back(cast<Value>(paddedSize));
+    }
+    LLVM_DEBUG(DBGS() << "----new dim size: " << paddedSize << "\n");
+  }
 }
 
 /// Pad the `opOperand` in the "paddingDimensions" using the padding value and
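
A note on the affine expression above: `sizeSym.ceilDiv(padToMultipleOf) * padToMultipleOf` performs the same round-up-to-multiple computation as the `ceil` lambda removed from the old code, but symbolically, so it can fold to a constant for static sizes and otherwise produce an affine.apply for dynamic ones. A minimal standalone equivalent for static sizes, shown only for illustration:

  // Rounds `size` up to the next multiple of `multiple`, e.g. 13 -> 16 for
  // multiple = 8: ((13 + 8 - 1) / 8) * 8 = 16.
  int64_t roundUpToMultiple(int64_t size, int64_t multiple) {
    return ((size + multiple - 1) / multiple) * multiple;
  }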
@@ -107,20 +192,21 @@ static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
           options.padToMultipleOf->size() == options.paddingDimensions.size()) &&
          "invalid number of elements in padToMultipleOf");
 
-  // Compute padded shape.
-  SmallVector<int64_t> paddedShape;
-  bool alreadyHasRequestedShape = false;
-  if (failed(computePaddedShape(opToPad, opOperand, options, paddedShape,
-                                alreadyHasRequestedShape)))
+  // Initialize the padded shape and get whether it requires padding.
+  PaddedShape shape;
+  FailureOr<bool> alreadyHasRequestedShape =
+      shape.initialize(opToPad, opOperand, options);
+  if (failed(alreadyHasRequestedShape)) {
     return rewriter.notifyMatchFailure(opToPad,
                                        "--failed to compute padded shape");
+  }
 
-  // Return the unpadded operand if padding to a static shape is not needed and
+  // Return the un-padded operand if padding to a static shape is not needed and
   // if the nofold flag is not set.
   bool nofold = opOperand->getOperandNumber() < options.nofoldFlags.size()
                     ? bool(options.nofoldFlags[opOperand->getOperandNumber()])
                     : false;
-  if (!nofold && alreadyHasRequestedShape)
+  if (!nofold && *alreadyHasRequestedShape)
     return opOperand->get();
 
   // Fail if `paddingValues` specifies no padding value.
@@ -140,13 +226,18 @@ static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
         opToPad.getLoc(), cast<TypedAttr>(paddingAttr));
   }
 
+  // Computes the padded shape.
+  if (!*alreadyHasRequestedShape)
+    shape.computePadding(rewriter, opOperand->get());
+
   // Pad the operand to the bounding box defined by `paddedShape`.
-  auto paddedTensorType = RankedTensorType::get(
-      paddedShape, getElementTypeOrSelf(opOperand->get()));
+  RankedTensorType paddedTensorType =
+      shape.getType(getElementTypeOrSelf(opOperand->get()));
   LLVM_DEBUG(DBGS() << "--SUCCESS, makeComposedPadHighOp with type: "
                     << paddedTensorType);
   return makeComposedPadHighOp(rewriter, opToPad->getLoc(), paddedTensorType,
-                               opOperand->get(), paddingValue, nofold);
+                               opOperand->get(), paddingValue, nofold,
+                               shape.dynDims);
 }
 
 LogicalResult

mlir/lib/Dialect/Linalg/Utils/Utils.cpp

Lines changed: 16 additions & 8 deletions
@@ -244,11 +244,13 @@ bool isReductionIterator(utils::IteratorType iteratorType) {
 }
 
 Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
-                            Value source, Value pad, bool nofold) {
+                            Value source, Value pad, bool nofold,
+                            ValueRange typeDynDims) {
   // Exit if `source` is not defined by an ExtractSliceOp.
   auto sliceOp = source.getDefiningOp<tensor::ExtractSliceOp>();
   if (!sliceOp)
-    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);
+    return tensor::createPadHighOp(type, source, pad, nofold, loc, b,
+                                   typeDynDims);
 
   // Search the `source` use-def chain for padded LinalgOps.
   Value current = sliceOp.getSource();
@@ -264,24 +266,28 @@ Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
   // Exit if the search fails to match a tensor::PadOp at the end of the matched
   // LinalgOp sequence.
   if (!padOp)
-    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);
+    return tensor::createPadHighOp(type, source, pad, nofold, loc, b,
+                                   typeDynDims);
 
   // Exit if the padded result type does not match.
   if (sliceOp.getSource().getType() != type)
-    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);
+    return tensor::createPadHighOp(type, source, pad, nofold, loc, b,
+                                   typeDynDims);
 
   // Exit if the LinalgOps are not high padded.
   if (llvm::any_of(padOp.getMixedLowPad(), [](OpFoldResult ofr) {
         return getConstantIntValue(ofr) != static_cast<int64_t>(0);
       }))
-    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);
+    return tensor::createPadHighOp(type, source, pad, nofold, loc, b,
+                                   typeDynDims);
 
   // Exit if `padOpSliceOp`, which defines the slice used by
   // `padOp`, is rank-reducing.
   auto padOpSliceOp = padOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
   if (!padOpSliceOp ||
       sliceOp.getMixedSizes().size() != padOpSliceOp.getMixedSizes().size())
-    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);
+    return tensor::createPadHighOp(type, source, pad, nofold, loc, b,
+                                   typeDynDims);
 
   // Exit if the sizes of the dynamic sizes of `sliceOp` do not match the size
   // of the slice padded by `padOp`.
@@ -290,14 +296,16 @@ Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
           [](std::tuple<OpFoldResult, OpFoldResult> it) {
             return !isEqualConstantIntOrValue(std::get<0>(it), std::get<1>(it));
           }))
-    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);
+    return tensor::createPadHighOp(type, source, pad, nofold, loc, b,
+                                   typeDynDims);
 
   // Exit if the padding values do not match.
   Attribute padOpPadAttr, padAttr;
   Value padOpPad = padOp.getConstantPaddingValue();
   if (!padOpPad || !matchPattern(padOpPad, m_Constant(&padOpPadAttr)) ||
       !matchPattern(pad, m_Constant(&padAttr)) || padOpPadAttr != padAttr)
-    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);
+    return tensor::createPadHighOp(type, source, pad, nofold, loc, b,
+                                   typeDynDims);
 
   // Return the padded result if the padding values and sizes match.
   return sliceOp.getSource();

mlir/test/Dialect/Linalg/transform-op-pad.mlir

Lines changed: 2 additions & 2 deletions
@@ -300,7 +300,7 @@ func.func @negative_no_ub_estimate(%arg0: tensor<?x12xf32>,
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    // expected-error @below {{ailed to pad op}}
+    // expected-error @below {{failed to pad op}}
     %padded, %pad, %copy_back = transform.structured.pad %0 {
       padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32],
       // Note - attempting to pad non-static dim
@@ -416,6 +416,6 @@ module attributes {transform.with_named_sequence} {
       padding_dimensions=[0, 1, 2],
      nofold_flags=[1, 1, 1]
     } : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
-    transform.yield 
+    transform.yield
   }
 }
