#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -22,53 +23,93 @@ using namespace mlir::linalg;
// Debug-print helpers: prefix every debug line with the pass's DEBUG_TYPE.
// Note: "[" DEBUG_TYPE "]: " relies on adjacent string-literal concatenation.
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")
2425
25- // / Compute the padded shape of the given operand. The operand is padded to a
26- // / static bounding box according to the specified padding options.
27- static LogicalResult computePaddedShape (linalg::LinalgOp opToPad,
26+ namespace {
27+ // / Helper class for storing padding information.
28+ struct PaddingInfo {
29+ PaddingInfo (int64_t padToMultipleOf = 1 , OpFoldResult size = {})
30+ : padToMultipleOf(padToMultipleOf), size(size) {}
31+ // / Pad the tensor to a multiple of.
32+ int64_t padToMultipleOf = 1 ;
33+ // / The size used for padding.
34+ OpFoldResult size = {};
35+ };
36+
37+ // / Helper class for storing and computing the padded shape.
38+ struct PaddedShape {
39+ // / Initializes the shape information and on success it returns whether the
40+ // / shape of the operand will change. Returns failure if the operand cannot be
41+ // / padded.
42+ FailureOr<bool > initialize (linalg::LinalgOp opToPad, OpOperand *opOperand,
43+ const LinalgPaddingOptions &options);
44+
45+ // / Computs the padded shape.
46+ void computePadding (OpBuilder &builder, Value operand);
47+
48+ // / Returns the new tensor type.
49+ RankedTensorType getType (Type elemTy) {
50+ return RankedTensorType::get (shape, elemTy);
51+ }
52+
53+ SmallVector<Value> dynDims;
54+
55+ private:
56+ SmallVector<int64_t > shape;
57+ DenseMap<int64_t , PaddingInfo> dimToInfo;
58+ };
59+ } // namespace
60+
61+ FailureOr<bool > PaddedShape::initialize (linalg::LinalgOp opToPad,
2862 OpOperand *opOperand,
29- const LinalgPaddingOptions &options,
30- SmallVector<int64_t > &paddedShape,
31- bool &alreadyHasRequestedShape) {
63+ const LinalgPaddingOptions &options) {
3264 AffineMap indexingMap = opToPad.getMatchingIndexingMap (opOperand);
33- ArrayRef<int64_t > shape = opToPad.getShape (opOperand);
65+
66+ // Initialize the padded shape.
67+ llvm::append_range (shape, opToPad.getShape (opOperand));
3468
3569 // Collect the shape dimensions that are a function of "paddingDimensions",
3670 // along with the multiple that they should be padded to ("1" if none).
37- alreadyHasRequestedShape = true ;
38- DenseMap<int64_t , int64_t > shapeDimToMultiple;
71+ bool alreadyHasRequestedShape = true ;
3972 for (const auto &dimEn : enumerate(options.paddingDimensions )) {
4073 for (const auto &en : enumerate(indexingMap.getResults ())) {
4174 if (en.value ().isFunctionOfDim (dimEn.value ())) {
75+ PaddingInfo paddingInfo;
4276 int64_t dimSize = shape[en.index ()];
4377 if (options.padToMultipleOf .has_value ()) {
44- shapeDimToMultiple[en. index ()] =
78+ paddingInfo. padToMultipleOf =
4579 (*options.padToMultipleOf )[dimEn.index ()];
4680 } else {
47- shapeDimToMultiple[en. index ()] = 1 ;
81+ paddingInfo. padToMultipleOf = 1 ;
4882 }
49- if (ShapedType::isDynamic (dimSize)) {
50- alreadyHasRequestedShape = false ;
51- } else if (dimSize % shapeDimToMultiple[en.index ()] != 0 ) {
83+
84+ // Check if the user provided a size in the options.
85+ paddingInfo.size =
86+ options.getSizeToPadTo (opOperand->getOperandNumber (), en.index ());
87+
88+ // Set the padding info.
89+ dimToInfo[en.index ()] = paddingInfo;
90+ if (ShapedType::isDynamic (dimSize) ||
91+ dimSize % paddingInfo.padToMultipleOf != 0 ||
92+ !paddingInfo.size .isNull ()) {
5293 alreadyHasRequestedShape = false ;
5394 }
5495 }
5596 }
5697 }
5798
58- // Helper function to round a number up to a given multiple.
59- auto ceil = [](int64_t val, int64_t multiple) {
60- return ((val + multiple - 1 ) / multiple) * multiple;
61- };
62-
6399 // Upper bound the sizes to obtain a static bounding box.
64- paddedShape.assign (shape.begin (), shape.end ());
65100 for (int64_t i = 0 , e = shape.size (); i < e; ++i) {
66- LLVM_DEBUG (DBGS () << " --compute padded size for dim " << i << " \n " );
101+ LLVM_DEBUG (DBGS () << " --computing un- padded size for dim " << i << " \n " );
67102 // Skip dimensions that do not require padding.
68- if (!shapeDimToMultiple .contains (i)) {
103+ if (!dimToInfo .contains (i)) {
69104 LLVM_DEBUG (DBGS () << " ----dim does not require padding, SKIP\n " );
70105 continue ;
71106 }
107+ PaddingInfo &info = dimToInfo[i];
108+ if (info.size ) {
109+ LLVM_DEBUG (DBGS () << " ----the user provided the size: " << info.size
110+ << " \n " );
111+ continue ;
112+ }
72113 // Otherwise, try to compute a constant upper bound for the size value.
73114 FailureOr<int64_t > upperBound =
74115 ValueBoundsConstraintSet::computeConstantBound (
@@ -77,14 +118,58 @@ static LogicalResult computePaddedShape(linalg::LinalgOp opToPad,
77118 /* dim=*/ i},
78119 /* stopCondition=*/ nullptr , /* closedUB=*/ true );
79120 if (failed (upperBound)) {
80- LLVM_DEBUG (DBGS () << " ----could not compute a bounding box for padding" );
121+ LLVM_DEBUG (
122+ DBGS () << " ----could not compute a bounding box for padding\n " );
81123 return failure ();
82124 }
83- paddedShape[i] = ceil (*upperBound, shapeDimToMultiple[i]);
84- LLVM_DEBUG (DBGS () << " ----new dim size: " << paddedShape[i] << " \n " );
125+ info.size =
126+ IntegerAttr::get (IndexType::get (opToPad.getContext ()), *upperBound);
127+ LLVM_DEBUG (DBGS () << " ----new un-padded size: " << info.size << " \n " );
85128 }
129+ return alreadyHasRequestedShape;
130+ }
86131
87- return success ();
132+ void PaddedShape::computePadding (OpBuilder &builder, Value operand) {
133+ Location loc = operand.getLoc ();
134+ AffineExpr sizeSym = builder.getAffineSymbolExpr (0 );
135+
136+ // Compute the padding for each dimension.
137+ for (auto &&[i, dim] : llvm::enumerate (shape)) {
138+ LLVM_DEBUG (DBGS () << " --computing padded size for dim " << i << " \n " );
139+
140+ // Get the padding info or default info for the shape dimension.
141+ PaddingInfo paddingInfo = dimToInfo.lookup (i);
142+
143+ // Skip dimensions that do not require padding.
144+ if (paddingInfo.size .isNull ()) {
145+ LLVM_DEBUG (DBGS () << " ----dim does not require padding, SKIP\n " );
146+
147+ // We still need to push the size as `makeComposedPadHighOp` expects a
148+ // range with all the dynamic sizes, whether they're being padded or not.
149+ if (ShapedType::isDynamic (dim)) {
150+ dynDims.push_back (
151+ cast<Value>(tensor::getMixedSize (builder, loc, operand, i)));
152+ }
153+ continue ;
154+ }
155+
156+ // Compute the padded size to be a multiple of `padToMultipleOf`.
157+ AffineExpr szExpr = (sizeSym).ceilDiv (paddingInfo.padToMultipleOf ) *
158+ paddingInfo.padToMultipleOf ;
159+ OpFoldResult paddedSize = affine::makeComposedFoldedAffineApply (
160+ builder, loc, szExpr, paddingInfo.size );
161+ assert (paddedSize && " invalid arguments to affine apply" );
162+
163+ if (auto cstSzAttr = dyn_cast<Attribute>(paddedSize)) {
164+ // Update the shape as the size is static.
165+ dim = cast<IntegerAttr>(cstSzAttr).getValue ().getZExtValue ();
166+ } else {
167+ // Add a dynamic dimension.
168+ dim = ShapedType::kDynamic ;
169+ dynDims.push_back (cast<Value>(paddedSize));
170+ }
171+ LLVM_DEBUG (DBGS () << " ----new dim size: " << paddedSize << " \n " );
172+ }
88173}
89174
90175// / Pad the `opOperand` in the "paddingDimensions" using the padding value and
@@ -107,20 +192,21 @@ static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
107192 options.padToMultipleOf ->size () == options.paddingDimensions .size ()) &&
108193 " invalid number of elements in padToMultipleOf" );
109194
110- // Compute padded shape.
111- SmallVector< int64_t > paddedShape ;
112- bool alreadyHasRequestedShape = false ;
113- if ( failed ( computePaddedShape ( opToPad, opOperand, options, paddedShape,
114- alreadyHasRequestedShape)))
195+ // Initialize the padded shape and get whether it requires padding .
196+ PaddedShape shape ;
197+ FailureOr< bool > alreadyHasRequestedShape =
198+ shape. initialize ( opToPad, opOperand, options);
199+ if ( failed ( alreadyHasRequestedShape)) {
115200 return rewriter.notifyMatchFailure (opToPad,
116201 " --failed to compute padded shape" );
202+ }
117203
118- // Return the unpadded operand if padding to a static shape is not needed and
204+ // Return the un-padded operand if padding to a static shape is not needed and
119205 // if the nofold flag is not set.
120206 bool nofold = opOperand->getOperandNumber () < options.nofoldFlags .size ()
121207 ? bool (options.nofoldFlags [opOperand->getOperandNumber ()])
122208 : false ;
123- if (!nofold && alreadyHasRequestedShape)
209+ if (!nofold && * alreadyHasRequestedShape)
124210 return opOperand->get ();
125211
126212 // Fail if `paddingValues` specifies no padding value.
@@ -140,13 +226,18 @@ static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
140226 opToPad.getLoc (), cast<TypedAttr>(paddingAttr));
141227 }
142228
229+ // Computes the padded shape.
230+ if (!*alreadyHasRequestedShape)
231+ shape.computePadding (rewriter, opOperand->get ());
232+
143233 // Pad the operand to the bounding box defined by `paddedShape`.
144- auto paddedTensorType = RankedTensorType::get (
145- paddedShape, getElementTypeOrSelf (opOperand->get ()));
234+ RankedTensorType paddedTensorType =
235+ shape. getType ( getElementTypeOrSelf (opOperand->get ()));
146236 LLVM_DEBUG (DBGS () << " --SUCCESS, makeComposedPadHighOp with type: "
147237 << paddedTensorType);
148238 return makeComposedPadHighOp (rewriter, opToPad->getLoc (), paddedTensorType,
149- opOperand->get (), paddingValue, nofold);
239+ opOperand->get (), paddingValue, nofold,
240+ shape.dynDims );
150241}
151242
152243LogicalResult
0 commit comments