
Commit a50f53f

address Max's comments

Signed-off-by: James Newling <[email protected]>

1 parent d7c6c3e

File tree

3 files changed (+38, -38 lines)

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 15 additions & 11 deletions

@@ -620,7 +620,7 @@ LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
 /// In the future, more general interfaces can be devised to encode similar
 /// shape evolutions and map between an op and its operands.
 SmallVector<OpFoldResult>
-computePaddedShape(OpBuilder &rewriter, TypedValue<RankedTensorType> v,
+computePaddedShape(OpBuilder &, TypedValue<RankedTensorType> v,
                    AffineMap indexingMap, ArrayRef<OpFoldResult> indexingSizes,
                    const PadTilingInterfaceOptions &options);

@@ -630,17 +630,13 @@ using PadSizeComputationFunction =
     const PadTilingInterfaceOptions &)>;
 
 /// Specific helper for Linalg ops.
-FailureOr<SmallVector<OpFoldResult>> computeIndexingMapOpInterfacePaddedShape(
-    OpBuilder &rewriter, OpOperand &operandToPad,
-    ArrayRef<Range> iterationDomain, const PadTilingInterfaceOptions &options);
+FailureOr<SmallVector<OpFoldResult>>
+computeIndexingMapOpInterfacePaddedShape(OpBuilder &, OpOperand &operandToPad,
+                                         ArrayRef<Range> iterationDomain,
+                                         const PadTilingInterfaceOptions &);
 
-/// Pad the iterator dimensions of `toPad`.
-/// * "options.paddingSizes" indicates that each padding dimension should be
-///   padded to the specified padding size.
-/// * "options.padToMultipleOf" indicates that the paddingSizes should be
-//    interpreted as the bounding box (dynamic) value to pad to.
-/// * Use "options.paddingValues" to set the padding value of the created
-//    tensor::PadOp.
+/// Operations and values created in the process of padding a TilingInterface
+/// operation.
 struct PadTilingInterfaceResult {
   /// The operands of the padded op.
   SmallVector<tensor::PadOp> padOps;

@@ -649,6 +645,14 @@ struct PadTilingInterfaceResult {
   /// Slices of the padded op's results, same types as `toPad`.
   SmallVector<Value> replacements;
 };
+
+/// Pad the iterator dimensions of `toPad`.
+/// * "options.paddingSizes" indicates that each padding dimension should be
+///   padded to the specified padding size.
+/// * "options.padToMultipleOf" indicates that the paddingSizes should be
+//    interpreted as the bounding box (dynamic) value to pad to.
+/// * Use "options.paddingValues" to set the padding value of the created
+//    tensor::PadOp.
 FailureOr<PadTilingInterfaceResult>
 rewriteAsPaddedOp(OpBuilder &, TilingInterface toPad,
                   PadTilingInterfaceOptions options,
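
The doc comment relocated above spells out the options consumed by `rewriteAsPaddedOp`. Below is a minimal call-site sketch, not part of this commit: it assumes `PadTilingInterfaceOptions` exposes `paddingSizes`, `padToMultipleOf`, and `paddingValues` as directly assignable members (the comment names them, but this diff does not show the option-building API), that `rewriter` and `toPad` are an in-scope `RewriterBase` and `TilingInterface`, and that the trailing size-computation parameter of `rewriteAsPaddedOp` has a default.

// Hypothetical call site: pad both iterator dimensions of `toPad` up to a
// multiple of 16, filling the created tensor::PadOp with f32 zeros.
PadTilingInterfaceOptions options;
options.paddingSizes = {rewriter.getIndexAttr(16), rewriter.getIndexAttr(16)};
options.padToMultipleOf = true; // sizes are rounding quanta, not fixed bounds
options.paddingValues = {rewriter.getZeroAttr(rewriter.getF32Type())};

FailureOr<PadTilingInterfaceResult> padded =
    rewriteAsPaddedOp(rewriter, toPad, options);
if (failed(padded))
  return failure();
// `padOps` holds the created tensor::PadOp ops; `replacements` are slices of
// the padded op's results with the same types as `toPad`'s results.
rewriter.replaceOp(toPad, padded->replacements);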

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 1 addition & 4 deletions

@@ -2469,14 +2469,11 @@ transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
       diag.attachNote(target->getLoc()) << "target op";
       return diag;
     }
-
-    const auto &[paddedOperands, paddedOp, slicedResults] = *maybePadOps;
+    const auto &[paddedOperands, paddedOp, slicedResults] = maybePadOps.value();
 
     // Set transform results.
     paddedOps.push_back(paddedOp);
     padOps.append(paddedOperands.begin(), paddedOperands.end());
-
-    // erase targetOp:
     rewriter.replaceOp(targetOp.getOperation(), slicedResults);
   }

mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp

Lines changed: 22 additions & 23 deletions

@@ -96,7 +96,7 @@ static int64_t extractConstantMultiplier(AffineExpr expr) {
 /// In the future, more general interfaces can be devised to encode similar
 /// shape evolutions and map between an op and its operands.
 SmallVector<OpFoldResult>
-linalg::computePaddedShape(OpBuilder &rewriter, TypedValue<RankedTensorType> v,
+linalg::computePaddedShape(OpBuilder &builder, TypedValue<RankedTensorType> v,
                            AffineMap indexingMap,
                            ArrayRef<OpFoldResult> indexingSizes,
                            const PadTilingInterfaceOptions &options) {

@@ -110,7 +110,7 @@ linalg::computePaddedShape(OpBuilder &rewriter, TypedValue<RankedTensorType> v,
 
   // "Full-rank" padding specification.
   SmallVector<OpFoldResult> paddingSizes =
-      getFullRankPaddingSizes(rewriter, indexingSizes, options);
+      getFullRankPaddingSizes(builder, indexingSizes, options);
 
   // For each dimension in the operand's shape, iterate over indexingSizes and
   // add the various term contributions.

@@ -148,28 +148,27 @@ linalg::computePaddedShape(OpBuilder &rewriter, TypedValue<RankedTensorType> v,
       OpFoldResult paddingDimOfr;
       if (options.padToMultipleOf) {
         AffineExpr d0, s0;
-        bindDims(rewriter.getContext(), d0);
-        bindSymbols(rewriter.getContext(), s0);
+        bindDims(builder.getContext(), d0);
+        bindSymbols(builder.getContext(), s0);
         AffineMap ceilMap = AffineMap::get(1, 1, d0.ceilDiv(s0) * s0);
         AffineMap composedMap = projectedMap.compose(ceilMap);
         paddingDimOfr = affine::makeComposedFoldedAffineApply(
-            rewriter, loc, composedMap,
-            {indexingSizes[paddingDim], paddingSize},
+            builder, loc, composedMap, {indexingSizes[paddingDim], paddingSize},
             /*composeAffineMin=*/true);
       } else {
         // Otherwise just set to paddingSize.
         paddingDimOfr = affine::makeComposedFoldedAffineApply(
-            rewriter, loc, projectedMap, paddingSize);
+            builder, loc, projectedMap, paddingSize);
       }
 
       // Adjust for the maximum accessed index, which is (paddingSize - 1) *
       // multiplier.
       AffineExpr d0;
-      bindDims(rewriter.getContext(), d0);
+      bindDims(builder.getContext(), d0);
       int64_t multiplier = extractConstantMultiplier(projectedMap.getResult(0));
       AffineMap subtractMap = AffineMap::get(1, 0, d0 - multiplier);
       OpFoldResult maxAccessIdx = affine::makeComposedFoldedAffineApply(
-          rewriter, loc, subtractMap, {paddingDimOfr});
+          builder, loc, subtractMap, {paddingDimOfr});
       terms.push_back(maxAccessIdx);
 
       LLVM_DEBUG(DBGS() << "------new term: " << terms.back() << "\n");

@@ -178,19 +177,19 @@ linalg::computePaddedShape(OpBuilder &rewriter, TypedValue<RankedTensorType> v,
     // If there are no terms, just return the dim.
     if (terms.empty()) {
       paddedShape[resultIndex] =
-          createFoldedDimOp(rewriter, loc, v, resultIndex);
+          createFoldedDimOp(builder, loc, v, resultIndex);
       continue;
     }
 
     // Sum individual terms' contributions.
     SmallVector<AffineExpr> dims(terms.size());
-    bindDimsList(rewriter.getContext(), MutableArrayRef{dims});
+    bindDimsList(builder.getContext(), MutableArrayRef{dims});
     AffineExpr sumExpr = dims.front();
     for (unsigned i = 1; i < dims.size(); ++i)
       sumExpr = sumExpr + dims[i];
     // Add 1 to the maximum accessed index and get the final padded size.
-    OpFoldResult paddedDimOfr = affine::makeComposedFoldedAffineApply(
-        rewriter, loc, sumExpr + 1, terms);
+    OpFoldResult paddedDimOfr =
+        affine::makeComposedFoldedAffineApply(builder, loc, sumExpr + 1, terms);
     paddedShape[resultIndex] = paddedDimOfr;
   }

@@ -199,17 +198,17 @@ linalg::computePaddedShape(OpBuilder &rewriter, TypedValue<RankedTensorType> v,
 
 FailureOr<SmallVector<OpFoldResult>>
 linalg::computeIndexingMapOpInterfacePaddedShape(
-    OpBuilder &rewriter, OpOperand &operandToPad,
+    OpBuilder &builder, OpOperand &operandToPad,
     ArrayRef<Range> iterationDomain, const PadTilingInterfaceOptions &options) {
   auto transferOp =
       llvm::dyn_cast<IndexingMapOpInterface>(operandToPad.getOwner());
   if (!transferOp)
     return failure();
 
   // clang-format off
-  assert(llvm::all_of(iterationDomain, [&rewriter](Range r) {
-    return r.offset == OpFoldResult(rewriter.getIndexAttr(0)) &&
-           r.stride == OpFoldResult(rewriter.getIndexAttr(1));
+  assert(llvm::all_of(iterationDomain, [&builder](Range r) {
+    return r.offset == OpFoldResult(builder.getIndexAttr(0)) &&
+           r.stride == OpFoldResult(builder.getIndexAttr(1));
   }) && "expected 0-offset 1-stride loop ranges");
   // clang-format on
   SmallVector<OpFoldResult> loopUpperBounds;

@@ -219,29 +218,29 @@ linalg::computeIndexingMapOpInterfacePaddedShape(
 
   AffineMap indexingMap = transferOp.getMatchingIndexingMap(&operandToPad);
   return computePaddedShape(
-      rewriter, cast<TypedValue<RankedTensorType>>(operandToPad.get()),
+      builder, cast<TypedValue<RankedTensorType>>(operandToPad.get()),
       indexingMap, loopUpperBounds, options);
 }
 
 /// Pad a single operand to `paddedShape` using `paddingValueAttr` as padding
 /// Value.
-static Value padOperand(OpBuilder &rewriter, TilingInterface opToPad,
+static Value padOperand(OpBuilder &builder, TilingInterface opToPad,
                         TypedValue<RankedTensorType> v,
                         ArrayRef<OpFoldResult> paddedShape,
                         Attribute paddingValueAttr) {
   Value paddingValue;
   if (auto complexTy =
           dyn_cast<ComplexType>(getElementTypeOrSelf(v.getType()))) {
     if (auto complexAttr = dyn_cast<ArrayAttr>(paddingValueAttr)) {
-      paddingValue = complex::ConstantOp::create(rewriter, opToPad.getLoc(),
+      paddingValue = complex::ConstantOp::create(builder, opToPad.getLoc(),
                                                  complexTy, complexAttr);
     }
   } else if (isa<ub::PoisonAttr>(paddingValueAttr)) {
-    paddingValue = ub::PoisonOp::create(rewriter, opToPad.getLoc(),
+    paddingValue = ub::PoisonOp::create(builder, opToPad.getLoc(),
                                         getElementTypeOrSelf(v.getType()));
   } else if (auto typedAttr = dyn_cast<TypedAttr>(paddingValueAttr)) {
     paddingValue =
-        arith::ConstantOp::create(rewriter, opToPad.getLoc(), typedAttr);
+        arith::ConstantOp::create(builder, opToPad.getLoc(), typedAttr);
   }
   assert(paddingValue && "failed to create value from padding attribute");

@@ -260,7 +259,7 @@ static Value padOperand(OpBuilder &rewriter, TilingInterface opToPad,
       RankedTensorType::get(tensorShape, getElementTypeOrSelf(v));
   LLVM_DEBUG(DBGS() << "--SUCCESS, makeComposedPadHighOp with type: "
                     << paddedTensorType);
-  return makeComposedPadHighOp(rewriter, opToPad.getLoc(), paddedTensorType, v,
+  return makeComposedPadHighOp(builder, opToPad.getLoc(), paddedTensorType, v,
                                paddingValue, /*nofold=*/false, dynDims);
 }

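A worked numeric example of the padded-size arithmetic in `computePaddedShape` (values invented for illustration; it mirrors the `ceilMap`/`subtractMap` composition above and the final `sumExpr + 1`):

#include <cstdint>
#include <cstdio>

// One term: the operand dimension is accessed as 2 * d0 (constant multiplier
// 2), the iteration size along d0 is 13, and padToMultipleOf is set with
// padding size 8.
int main() {
  int64_t indexingSize = 13, multiple = 8, multiplier = 2;
  // ceilMap d0.ceilDiv(s0) * s0: ceilDiv(13, 8) * 8 = 16.
  int64_t rounded = ((indexingSize + multiple - 1) / multiple) * multiple;
  // Composed with the projected map 2 * d0: 2 * 16 = 32.
  int64_t composed = multiplier * rounded;
  // subtractMap d0 - multiplier, the max accessed index: (16 - 1) * 2 = 30.
  int64_t maxAccessIdx = composed - multiplier;
  // Add 1 to the max accessed index: the padded dim size is 31.
  std::printf("padded size = %lld\n", static_cast<long long>(maxAccessIdx + 1));
  return 0;
}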