Commit 212e86d

[Codegen] Refactor CombineLayoutTransformation to use patterns (#21592)
Refactors the CombineLayoutTransformation pass to use pattern rewrites for folding relayout ops into map_scatter. Having patterns for these rewrites allows the relayout op folding patterns to run simultaneously with reshape and data-layout propagation patterns. This PR does not do that yet, in order to keep its scope small. The changes are effectively NFC, except for the additional relayout op foldings added to `populateCombineRelayoutOpPatterns`, so no tests need to change.

Signed-off-by: Max Dawkins <[email protected]>
1 parent 9323cd2 commit 212e86d
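
For context, moving to patterns means the relayout-op foldings can be registered in the same `RewritePatternSet` as other propagation patterns and run under a single greedy rewrite, which is the motivation stated above. Below is a minimal sketch of what such a combined driver might look like; `populateSomeReshapePropagationPatterns` is a hypothetical placeholder for the reshape/data-layout propagation patterns mentioned in the message, not an existing API.

```cpp
// Sketch only: running the relayout-op folding patterns added by this PR in
// the same greedy rewrite as (future) reshape/data-layout propagation
// patterns. populateSomeReshapePropagationPatterns is hypothetical.
static LogicalResult
runCombinedRelayoutAndPropagation(MLIRContext *ctx, FunctionOpInterface funcOp,
                                  PadDistributionConfigFn padDistributionConfigFn) {
  RewritePatternSet patterns(ctx);
  // Relayout-op foldings into iree_linalg_ext.map_scatter (this PR).
  populateCombineRelayoutOpPatterns(patterns, padDistributionConfigFn);
  // Hypothetical: reshape / data-layout propagation patterns would be added
  // here, e.g. populateSomeReshapePropagationPatterns(patterns);
  return applyPatternsGreedily(funcOp, std::move(patterns));
}
```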

File tree

4 files changed: +174 -167 lines

compiler/src/iree/compiler/Codegen/Common/CombineLayoutTransformation.cpp

Lines changed: 107 additions & 85 deletions
@@ -112,6 +112,63 @@ static MapScatterOp foldTransposeIntoMapScatter(RewriterBase &rewriter,
   return mapScatterOp;
 }
 
+/// Fold a tensor::ExpandShapeOp or tensor::CollapseShapeOp into a consumer
+/// `mapScatterOp`, by linearizing and then delinearizing the source indices
+/// of the `mapScatterOp`s index transformation.
+template <typename ReshapeOpTy>
+static IREE::LinalgExt::MapScatterOp
+foldReshapeIntoMapScatter(RewriterBase &rewriter, ReshapeOpTy reshapeOp,
+                          IREE::LinalgExt::MapScatterOp mapScatterOp) {
+  assert(mapScatterOp.getInput() == reshapeOp->getResult(0) &&
+         "expected reshapeOp to be the producer of mapScatterOp");
+  Location loc = reshapeOp->getLoc();
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPointAfter(reshapeOp);
+  SmallVector<OpFoldResult> srcDims =
+      tensor::getMixedSizes(rewriter, loc, reshapeOp.getSrc());
+  // There can be leftover tensor.dim ops consuming the result of the reshape,
+  // but they are expected to be folded into some affine.apply ops on the source
+  // sizes by later cleanup patterns.
+  SmallVector<OpFoldResult> resultDims =
+      tensor::getMixedSizes(rewriter, loc, reshapeOp.getResult());
+
+  auto indexTransformBuilder =
+      [&](ArrayRef<BlockArgument> srcIndices) -> SmallVector<Value> {
+    auto linearizeIndexOp = rewriter.create<affine::AffineLinearizeIndexOp>(
+        mapScatterOp->getLoc(), srcIndices, srcDims, /*disjoint=*/true);
+    auto delinearizeIndexOp = rewriter.create<affine::AffineDelinearizeIndexOp>(
+        mapScatterOp->getLoc(), linearizeIndexOp.getResult(), resultDims,
+        /*hasOuterBound=*/true);
+    return delinearizeIndexOp->getResults();
+  };
+  rewriter.modifyOpInPlace(mapScatterOp, [&]() {
+    mapScatterOp.insertTransformationAtStart(rewriter, indexTransformBuilder,
+                                             srcDims.size());
+    mapScatterOp.getInputMutable().assign(reshapeOp->getOperand(0));
+  });
+  return mapScatterOp;
+}
+
+/// Fold a tensor::ExpandShapeOp into a consumer `mapScatterOp`, by linearizing
+/// and then delinearizing the source indices of the `mapScatterOp`s index
+/// transformation.
+static MapScatterOp
+foldExpandShapeIntoMapScatter(RewriterBase &rewriter,
+                              tensor::ExpandShapeOp expandShapeOp,
+                              MapScatterOp mapScatterOp) {
+  return foldReshapeIntoMapScatter(rewriter, expandShapeOp, mapScatterOp);
+}
+
+/// Fold a tensor::CollapseShapeOp into a consumer `mapScatterOp`, by
+/// linearizing and then delinearizing the source indices of the
+/// `mapScatterOp`s index transformation.
+static MapScatterOp
+foldCollapseShapeIntoMapScatter(RewriterBase &rewriter,
+                                tensor::CollapseShapeOp collapseShapeOp,
+                                MapScatterOp mapScatterOp) {
+  return foldReshapeIntoMapScatter(rewriter, collapseShapeOp, mapScatterOp);
+}
+
 /// Fold an `extractSliceOp` into a consumer `mapScatterOp` by applying a mask
 /// based on the bounds of the extractSliceOp. Currently, only zero offsets and
 /// unit strides are supported.
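
To make the linearize-then-delinearize index transformation concrete, here is a standalone C++ sketch of the equivalent arithmetic for a made-up example (an expand_shape from tensor<12> to tensor<3x4>). After folding, the map_scatter input is addressed by the flat source index k, and the prepended transformation recovers the (row, col) index that the rest of the map_scatter transformation still expects; the shapes are illustrative only.

```cpp
// Sketch only: plain C++ analogue of the affine.linearize_index /
// affine.delinearize_index pair that foldReshapeIntoMapScatter prepends.
// Example shapes (tensor<12> expanded to tensor<3x4>) are made up.
#include <array>
#include <cassert>
#include <cstdint>

int main() {
  constexpr std::array<int64_t, 1> srcDims = {12};   // expand_shape source
  constexpr std::array<int64_t, 2> resultDims = {3, 4}; // expand_shape result
  for (int64_t k = 0; k < srcDims[0]; ++k) {
    // Linearize over srcDims: a single source index is already linear.
    int64_t linear = k;
    // Delinearize over resultDims: recover the (row, col) index that the
    // original map_scatter index transformation consumed.
    int64_t row = linear / resultDims[1];
    int64_t col = linear % resultDims[1];
    assert(row < resultDims[0] && row * resultDims[1] + col == k);
  }
  return 0;
}
```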
@@ -219,13 +276,7 @@ static void buildNestedDistributionLoops(
   });
 }
 
-/// Fold a tensor.pad op into a iree_linalg_ext.map_scatter op, and separate
-/// the writing of padding values into a separate operation on the buffer that
-/// the map_scatter op is ultimately written into. The result buffer is taken
-/// from the direct consumer of the `mapScatterOp`, which is expected to be an
-/// `iree_codegen.store_to_buffer` op. Return failure if the result buffer is
-/// not found.
-static FailureOr<MapScatterOp>
+FailureOr<MapScatterOp>
 foldPadIntoMapScatter(RewriterBase &rewriter, tensor::PadOp padOp,
                       MapScatterOp mapScatterOp,
                       PadDistributionConfigFn padDistributionConfigFn) {
@@ -316,14 +367,9 @@ foldPadIntoMapScatter(RewriterBase &rewriter, tensor::PadOp padOp,
   return mapScatterOp;
 }
 
-/// Fold the `op` into the `mapScatterOp`, if possible. The resulting
-/// map_scatter op is returned, if the `op` was folded. Otherwise, return
-/// failure. For `PadOp`s, use the `padDistributionConfigFn` to distribute
-/// the writing of padding values to the corresponding output buffer.
-static FailureOr<MapScatterOp>
-foldIntoMapScatter(RewriterBase &rewriter, Operation *op,
-                   MapScatterOp mapScatterOp,
-                   PadDistributionConfigFn padDistributionConfigFn) {
+FailureOr<MapScatterOp> foldIntoMapScatter(RewriterBase &rewriter,
+                                           Operation *op,
+                                           MapScatterOp mapScatterOp) {
   return llvm::TypeSwitch<Operation *, FailureOr<MapScatterOp>>(op)
       .Case<linalg::CopyOp>([&](linalg::CopyOp copyOp) {
        return foldIdentityLikeOpIntoMapScatter(rewriter, copyOp, mapScatterOp);
@@ -342,47 +388,9 @@ foldIntoMapScatter(RewriterBase &rewriter, Operation *op,
         return foldExtractSliceIntoMapScatter(rewriter, extractSliceOp,
                                               mapScatterOp);
       })
-      .Case<tensor::PadOp>([&](tensor::PadOp padOp) {
-        return foldPadIntoMapScatter(rewriter, padOp, mapScatterOp,
-                                     padDistributionConfigFn);
-      })
       .Default([](Operation *) { return failure(); });
 }
 
-/// Starting from the `root`, iteratively combine any relayout op producers
-/// into a single iree_linalg_ext.map_scatter op. An identity map_scatter op
-/// is inserted before the root, and then the producers of the map_scatter op
-/// are folded into the map_scatter until an unsupported op is reached.
-static void
-combineRelayoutOpChain(RewriterBase &rewriter, MapScatterOp mapScatterOp,
-                       PadDistributionConfigFn padDistributionConfigFn) {
-  Operation *relayoutOp = mapScatterOp.getInput().getDefiningOp();
-  if (!relayoutOp) {
-    return;
-  }
-  MapScatterOp combinedRelayoutOp = mapScatterOp;
-  while (relayoutOp) {
-    LDBG() << "Attempting to fold " << relayoutOp->getName()
-           << " into map_scatter op:\n"
-           << *relayoutOp;
-    FailureOr<MapScatterOp> maybeCombinedRelayoutOp = foldIntoMapScatter(
-        rewriter, relayoutOp, combinedRelayoutOp, padDistributionConfigFn);
-    if (failed(maybeCombinedRelayoutOp)) {
-      LDBG() << "Failed to fold " << relayoutOp->getName()
-             << " into map_scatter op";
-      break;
-    }
-    combinedRelayoutOp = maybeCombinedRelayoutOp.value();
-    LDBG() << "Successfully folded " << relayoutOp->getName()
-           << " into map_scatter. New map_scatter op:\n"
-           << combinedRelayoutOp;
-    relayoutOp = combinedRelayoutOp.getInput().getDefiningOp();
-  }
-  if (combinedRelayoutOp.isIdentity()) {
-    rewriter.replaceOp(combinedRelayoutOp, combinedRelayoutOp.getInput());
-  }
-}
-
 // Insert identity map_scatter op after the root and replace all uses.
 static MapScatterOp insertIdentityMapScatter(RewriterBase &rewriter,
                                              OpResult root) {
@@ -406,36 +414,50 @@ static MapScatterOp insertIdentityMapScatter(RewriterBase &rewriter,
   return mapScatterOp;
 }
 
-static bool isSupportedRelayoutOp(Operation *op) {
+bool isSupportedRelayoutOp(Operation *op) {
   return isa<tensor::ExpandShapeOp, tensor::CollapseShapeOp,
              tensor::ExtractSliceOp, tensor::PadOp, linalg::CopyOp,
              linalg::TransposeOp>(op);
 }
 
-/// Returns the leaves of all relayout op chains in the funcOp. A relayout op
-/// chain is a sequence of relayout ops (defined by `isSupportedRelayoutOp`)
-/// for which the only users of the ops in the chain are relayout ops, except
-/// for the leaves of the chain. The leaves are simply relayout ops that have
-/// non relayout op users.
-static SmallVector<OpResult> getRelayoutLeaves(FunctionOpInterface funcOp) {
-  SmallVector<OpResult> relayoutChainRoots;
-  funcOp->walk([&relayoutChainRoots](Operation *op) {
+/// Insert identity map_scatter ops after the given operation if it is a valid
+/// leaf op of a relayout op chain. A relayout op chain is a sequence of
+/// relayout ops (defined by `isSupportedRelayoutOp`) for which the only users
+/// of the ops in the chain are relayout ops, except for the leaves of the
+/// chain. The leaves are simply relayout ops that have non relayout op users.
+/// The `controlFn` is a callback on the leaf OpResult that provides control
+/// over whether or not to insert a map_scatter op.
+struct InsertMapScatterOpPattern : public RewritePattern {
+  InsertMapScatterOpPattern(MLIRContext *context,
+                            CombineRelayoutOpsControlFn controlFn = nullptr,
+                            PatternBenefit benefit = 1)
+      : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
+        controlFn(controlFn) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
     if (!isSupportedRelayoutOp(op)) {
-      return WalkResult::advance();
+      return failure();
    }
     // Relayout ops with only relayout op users are not leaves.
     auto isDimOrSupportedRelayoutOp = [](Operation *op) {
       return isSupportedRelayoutOp(op) || isa<tensor::DimOp>(op);
     };
     if (llvm::all_of(op->getUsers(), isDimOrSupportedRelayoutOp)) {
-      return WalkResult::advance();
+      return failure();
     }
     // All relayout ops have a single result.
-    relayoutChainRoots.push_back(op->getResult(0));
-    return WalkResult::advance();
-  });
-  return relayoutChainRoots;
-}
+    OpResult leaf = op->getResult(0);
+    if (controlFn && !controlFn(leaf)) {
+      return failure();
+    }
+    (void)insertIdentityMapScatter(rewriter, leaf);
+    return success();
+  }
+
+private:
+  CombineRelayoutOpsControlFn controlFn;
+};
 
 LogicalResult
 combineLayoutTransformation(MLIRContext *ctx, FunctionOpInterface funcOp,
@@ -499,24 +521,24 @@ combineLayoutTransformation(MLIRContext *ctx, FunctionOpInterface funcOp,
   IRRewriter rewriter(ctx);
   simplifyComplexRelayoutOps(rewriter, funcOp);
 
-  // Start from leaf ops, and combine producer relayout ops into a single
-  // map_scatter.
-  SmallVector<OpResult> relayoutLeaves = getRelayoutLeaves(funcOp);
-  for (OpResult leaf : relayoutLeaves) {
-    if (controlFn && !controlFn(leaf)) {
-      continue;
-    }
-    MapScatterOp mapScatterOp = insertIdentityMapScatter(rewriter, leaf);
-    combineRelayoutOpChain(rewriter, mapScatterOp, padDistributionConfigFn);
-  }
-
-  // Cleanup any tensor.dim ops that may be present after relayout
-  // combination.
-  RewritePatternSet cleanupPatterns(ctx);
-  memref::populateResolveRankedShapedTypeResultDimsPatterns(cleanupPatterns);
-  if (failed(applyPatternsGreedily(funcOp, std::move(cleanupPatterns)))) {
+  // Combine relayout operations into new map_scatter ops.
+  RewritePatternSet relayoutCombinationPatterns(ctx);
+  relayoutCombinationPatterns.add<InsertMapScatterOpPattern>(ctx, controlFn);
+  populateCombineRelayoutOpPatterns(relayoutCombinationPatterns,
+                                    padDistributionConfigFn);
+  memref::populateResolveRankedShapedTypeResultDimsPatterns(
+      relayoutCombinationPatterns);
+  if (failed(applyPatternsGreedily(funcOp,
+                                   std::move(relayoutCombinationPatterns)))) {
     return failure();
   }
+
+  // Clean up any identity map_scatter ops after combining.
+  funcOp->walk([&](MapScatterOp mapScatterOp) {
+    if (mapScatterOp.isIdentity()) {
+      rewriter.replaceOp(mapScatterOp, mapScatterOp.getInput());
+    }
+  });
   return success();
 }

compiler/src/iree/compiler/Codegen/Common/CombineLayoutTransformation.h

Lines changed: 27 additions & 0 deletions
@@ -7,6 +7,8 @@
 #ifndef IREE_COMPILER_CODEGEN_COMMON_COMBINELAYOUTTRANSFORMATION_H_
 #define IREE_COMPILER_CODEGEN_COMMON_COMBINELAYOUTTRANSFORMATION_H_
 
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
 
 namespace mlir::iree_compiler {
@@ -51,6 +53,31 @@ enum class RelayoutCombinationScope { Dispatch, Workgroup };
 CombineRelayoutOpsControlFn
 getCombineRelayoutOpsControlFn(IREE::Codegen::RelayoutCombinationScope scope);
 
+/// Returns true if the `op` type has a folding pattern into
+/// iree_linalg_ext.map_scatter.
+bool isSupportedRelayoutOp(Operation *op);
+
+/// Fold the `op` into the `mapScatterOp` and return the resulting map_scatter,
+/// or failure if the transformation is not supported. The `op` should be a
+/// supported relayout op, and not a tensor.pad. For tensor.pad, the folding is
+/// handled by `foldPadIntoMapScatter`, because it requires a
+/// `PadDistributionConfigFn`.
+FailureOr<IREE::LinalgExt::MapScatterOp>
+foldIntoMapScatter(RewriterBase &rewriter, Operation *op,
+                   IREE::LinalgExt::MapScatterOp mapScatterOp);
+
+/// Fold a tensor.pad op into a iree_linalg_ext.map_scatter op, and separate
+/// the writing of padding values into a separate operation on the buffer that
+/// the map_scatter op is ultimately written into. The result buffer is taken
+/// from the direct consumer of the `mapScatterOp`, which is expected to be an
+/// `iree_codegen.store_to_buffer` op. Return failure if the result buffer is
+/// not found. The `padDistributionConfigFn` provides distribution configs for
+/// the writing of padding values to the corresponding output buffer.
+FailureOr<IREE::LinalgExt::MapScatterOp>
+foldPadIntoMapScatter(RewriterBase &rewriter, tensor::PadOp padOp,
+                      IREE::LinalgExt::MapScatterOp mapScatterOp,
+                      PadDistributionConfigFn padDistributionConfigFn);
+
 /// Combines any layout/indexing transformation ops at the ends of a dispatch.
 /// Finds `iree_codegen.store_to_buffer` ops in the `funcOp`, and combines any
 /// layout transformation ops (like expand_shape, transpose, pack, etc.) that
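
Since `isSupportedRelayoutOp`, `foldIntoMapScatter`, and `foldPadIntoMapScatter` are now declared in this header, downstream code can fold a relayout producer into a map_scatter without running the whole pass. The sketch below is not part of this commit, and the pattern name is made up; it only illustrates one plausible way to drive the newly exposed functions from a custom rewrite pattern.

```cpp
// Sketch only (not part of this commit): fold a supported relayout producer
// (other than tensor.pad) into its map_scatter consumer, using the
// declarations exposed by CombineLayoutTransformation.h.
struct FoldRelayoutProducerIntoMapScatter
    : public OpRewritePattern<IREE::LinalgExt::MapScatterOp> {
  using OpRewritePattern<IREE::LinalgExt::MapScatterOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(IREE::LinalgExt::MapScatterOp mapScatterOp,
                                PatternRewriter &rewriter) const override {
    Operation *producer = mapScatterOp.getInput().getDefiningOp();
    if (!producer || !isSupportedRelayoutOp(producer) ||
        isa<tensor::PadOp>(producer)) {
      return failure();
    }
    // tensor.pad is excluded because it needs a PadDistributionConfigFn and
    // goes through foldPadIntoMapScatter instead.
    if (failed(foldIntoMapScatter(rewriter, producer, mapScatterOp))) {
      return failure();
    }
    return success();
  }
};
```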
