Commit f7f7ea8

[Codegen] Add reshape map_scatter folding to BlockDynamicDimensions (#21047)

Turns the map_scatter folding transformations for tensor.collapse_shape and tensor.expand_shape into pattern rewrites, and adds those patterns to the BlockDynamicDimensions pass. This allows the CombineLayoutTransformation pass to run before BlockDynamicDimensions without creating additional reshapes in the IR.

Signed-off-by: Max Dawkins <[email protected]>

1 parent bfc106e commit f7f7ea8
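
The new pattern-population entry point follows the usual populate-style convention, so any pass can pull these folds in alongside its own patterns. A minimal sketch of such wiring (the wrapper function and the greedy-driver call are illustrative assumptions, not part of this commit):

// Sketch only: applying the new relayout-combining patterns with a greedy
// rewrite. populateCombineRelayoutOpPatterns is the entry point added by this
// commit; the wrapper function and driver call are assumed boilerplate.
#include "iree/compiler/Codegen/Common/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static mlir::LogicalResult combineRelayoutOps(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  // Folds tensor.expand_shape / tensor.collapse_shape producers into their
  // iree_linalg_ext.map_scatter consumers.
  mlir::iree_compiler::populateCombineRelayoutOpPatterns(patterns);
  // Newer MLIR spells this applyPatternsGreedily.
  return mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}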

File tree: 5 files changed, +151 / -39 lines

compiler/src/iree/compiler/Codegen/Common/BlockDynamicDimensions.cpp
Lines changed: 3 additions & 0 deletions

@@ -6,6 +6,7 @@
 
 #include "iree/compiler/Codegen/Common/Passes.h"
 #include "iree/compiler/Codegen/Common/TensorDynamicDimAnalysis.h"
+#include "iree/compiler/Codegen/Common/Transforms.h"
 #include "iree/compiler/Codegen/Transforms/Transforms.h"
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
 #include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"

@@ -374,6 +375,7 @@ void BlockDynamicDimensionsPass::runOnOperation() {
   // "pushed-down" `tensor.collapse_shape` operation with their interface
   // bindings or `tensor.empty` operations.
   populateReshapeToInterfaceTensorPatterns(bubbleExpandShapePatterns);
+  populateCombineRelayoutOpPatterns(bubbleExpandShapePatterns);
   populateFoldTensorReshapeIntoBufferPatterns(bubbleExpandShapePatterns);
   tensor::populateFoldTensorEmptyPatterns(bubbleExpandShapePatterns);
   tensor::populateBubbleUpExpandShapePatterns(bubbleExpandShapePatterns);

@@ -409,6 +411,7 @@ void BlockDynamicDimensionsPass::runOnOperation() {
   // Add patterns to fold the remaining reshape operation with their interface
   // bindings or `tensor.empty` operations.
   populateReshapeToInterfaceTensorPatterns(removeBarrierOpsPatterns);
+  populateCombineRelayoutOpPatterns(removeBarrierOpsPatterns);
   populateFoldTensorReshapeIntoBufferPatterns(removeBarrierOpsPatterns);
   tensor::populateFoldTensorEmptyPatterns(removeBarrierOpsPatterns);
   linalg::FillOp::getCanonicalizationPatterns(removeBarrierOpsPatterns,

compiler/src/iree/compiler/Codegen/Common/CombineLayoutTransformation.cpp
Lines changed: 4 additions & 39 deletions

@@ -6,6 +6,7 @@
 
 #include "iree/compiler/Codegen/Common/CombineLayoutTransformation.h"
 #include "iree/compiler/Codegen/Common/Passes.h"
+#include "iree/compiler/Codegen/Common/Transforms.h"
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h"
 #include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
 #include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"

@@ -111,43 +112,6 @@ static MapScatterOp foldTransposeIntoMapScatter(RewriterBase &rewriter,
   return mapScatterOp;
 }
 
-/// Fold a tensor::ExpandShapeOp or tensor::CollapseShapeOp into a consumer
-/// `mapScatterOp`, by linearizing and then delinearizing the source indices
-/// of the `mapScatterOp`s index transformation.
-template <typename ReshapeOpTy>
-static MapScatterOp foldReshapeIntoMapScatter(RewriterBase &rewriter,
-                                              ReshapeOpTy reshapeOp,
-                                              MapScatterOp mapScatterOp) {
-  assert(mapScatterOp.getInput() == reshapeOp->getResult(0) &&
-         "expected reshapeOp to be the producer of mapScatterOp");
-  Location loc = reshapeOp->getLoc();
-  OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPointAfter(reshapeOp);
-  SmallVector<OpFoldResult> srcDims =
-      tensor::getMixedSizes(rewriter, loc, reshapeOp.getSrc());
-  // There can be leftover tensor.dim ops consuming the result of the reshape,
-  // but they will be folded into some affine.apply ops on the source sizes by
-  // later cleanup patterns.
-  SmallVector<OpFoldResult> resultDims =
-      tensor::getMixedSizes(rewriter, loc, reshapeOp.getResult());
-
-  auto indexTransformBuilder =
-      [&](ArrayRef<BlockArgument> srcIndices) -> SmallVector<Value> {
-    auto linearizeIndexOp = rewriter.create<affine::AffineLinearizeIndexOp>(
-        mapScatterOp->getLoc(), srcIndices, srcDims, /*disjoint=*/true);
-    auto delinearizeIndexOp = rewriter.create<affine::AffineDelinearizeIndexOp>(
-        mapScatterOp->getLoc(), linearizeIndexOp.getResult(), resultDims,
-        /*hasOuterBound=*/true);
-    return delinearizeIndexOp->getResults();
-  };
-  rewriter.modifyOpInPlace(mapScatterOp, [&]() {
-    mapScatterOp.insertTransformationAtStart(rewriter, indexTransformBuilder,
-                                             srcDims.size());
-    mapScatterOp.getInputMutable().assign(reshapeOp->getOperand(0));
-  });
-  return mapScatterOp;
-}
-
 /// Fold an `extractSliceOp` into a consumer `mapScatterOp` by applying a mask
 /// based on the bounds of the extractSliceOp. Currently, only zero offsets and
 /// unit strides are supported.

@@ -368,10 +332,11 @@ foldIntoMapScatter(RewriterBase &rewriter, Operation *op,
         return foldTransposeIntoMapScatter(rewriter, transposeOp, mapScatterOp);
       })
       .Case<tensor::ExpandShapeOp>([&](tensor::ExpandShapeOp expandOp) {
-        return foldReshapeIntoMapScatter(rewriter, expandOp, mapScatterOp);
+        return foldExpandShapeIntoMapScatter(rewriter, expandOp, mapScatterOp);
      })
      .Case<tensor::CollapseShapeOp>([&](tensor::CollapseShapeOp collapseOp) {
-        return foldReshapeIntoMapScatter(rewriter, collapseOp, mapScatterOp);
+        return foldCollapseShapeIntoMapScatter(rewriter, collapseOp,
+                                               mapScatterOp);
      })
      .Case<tensor::ExtractSliceOp>([&](tensor::ExtractSliceOp extractSliceOp) {
        return foldExtractSliceIntoMapScatter(rewriter, extractSliceOp,

compiler/src/iree/compiler/Codegen/Common/Transforms.cpp
Lines changed: 97 additions & 0 deletions

@@ -5,13 +5,110 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 #include "iree/compiler/Codegen/Common/Transforms.h"
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 
 #define DEBUG_TYPE "iree-codegen-common-transforms"
 
 namespace mlir::iree_compiler {
 
+//===----------------------------------------------------------------------===//
+// Combining Layout Transformation Ops
+//===----------------------------------------------------------------------===//
+
+/// Fold a tensor::ExpandShapeOp or tensor::CollapseShapeOp into a consumer
+/// `mapScatterOp`, by linearizing and then delinearizing the source indices
+/// of the `mapScatterOp`s index transformation.
+template <typename ReshapeOpTy>
+static IREE::LinalgExt::MapScatterOp
+foldReshapeIntoMapScatter(RewriterBase &rewriter, ReshapeOpTy reshapeOp,
+                          IREE::LinalgExt::MapScatterOp mapScatterOp) {
+  assert(mapScatterOp.getInput() == reshapeOp->getResult(0) &&
+         "expected reshapeOp to be the producer of mapScatterOp");
+  Location loc = reshapeOp->getLoc();
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPointAfter(reshapeOp);
+  SmallVector<OpFoldResult> srcDims =
+      tensor::getMixedSizes(rewriter, loc, reshapeOp.getSrc());
+  // There can be leftover tensor.dim ops consuming the result of the reshape,
+  // but they are expected to be folded into some affine.apply ops on the source
+  // sizes by later cleanup patterns.
+  SmallVector<OpFoldResult> resultDims =
+      tensor::getMixedSizes(rewriter, loc, reshapeOp.getResult());
+
+  auto indexTransformBuilder =
+      [&](ArrayRef<BlockArgument> srcIndices) -> SmallVector<Value> {
+    auto linearizeIndexOp = rewriter.create<affine::AffineLinearizeIndexOp>(
+        mapScatterOp->getLoc(), srcIndices, srcDims, /*disjoint=*/true);
+    auto delinearizeIndexOp = rewriter.create<affine::AffineDelinearizeIndexOp>(
+        mapScatterOp->getLoc(), linearizeIndexOp.getResult(), resultDims,
+        /*hasOuterBound=*/true);
+    return delinearizeIndexOp->getResults();
+  };
+  rewriter.modifyOpInPlace(mapScatterOp, [&]() {
+    mapScatterOp.insertTransformationAtStart(rewriter, indexTransformBuilder,
+                                             srcDims.size());
+    mapScatterOp.getInputMutable().assign(reshapeOp->getOperand(0));
+  });
+  return mapScatterOp;
+}
+
+IREE::LinalgExt::MapScatterOp
+foldExpandShapeIntoMapScatter(RewriterBase &rewriter,
+                              tensor::ExpandShapeOp expandShapeOp,
+                              IREE::LinalgExt::MapScatterOp mapScatterOp) {
+  return foldReshapeIntoMapScatter(rewriter, expandShapeOp, mapScatterOp);
+}
+
+IREE::LinalgExt::MapScatterOp
+foldCollapseShapeIntoMapScatter(RewriterBase &rewriter,
+                                tensor::CollapseShapeOp collapseShapeOp,
+                                IREE::LinalgExt::MapScatterOp mapScatterOp) {
+  return foldReshapeIntoMapScatter(rewriter, collapseShapeOp, mapScatterOp);
+}
+
+namespace {
+
+struct FoldExpandShapeIntoMapScatterPattern
+    : public OpRewritePattern<IREE::LinalgExt::MapScatterOp> {
+  using OpRewritePattern<IREE::LinalgExt::MapScatterOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IREE::LinalgExt::MapScatterOp mapScatterOp,
+                                PatternRewriter &rewriter) const override {
+    auto expandOp =
+        mapScatterOp.getInput().getDefiningOp<tensor::ExpandShapeOp>();
+    if (!expandOp) {
+      return failure();
+    }
+    (void)foldExpandShapeIntoMapScatter(rewriter, expandOp, mapScatterOp);
+    return success();
+  }
+};
+
+struct FoldCollapseShapeIntoMapScatterPattern
+    : public OpRewritePattern<IREE::LinalgExt::MapScatterOp> {
+  using OpRewritePattern<IREE::LinalgExt::MapScatterOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IREE::LinalgExt::MapScatterOp mapScatterOp,
+                                PatternRewriter &rewriter) const override {
+    auto collapseOp =
+        mapScatterOp.getInput().getDefiningOp<tensor::CollapseShapeOp>();
+    if (!collapseOp) {
+      return failure();
+    }
+    (void)foldCollapseShapeIntoMapScatter(rewriter, collapseOp, mapScatterOp);
+    return success();
+  }
+};
+
+} // namespace
+
+void populateCombineRelayoutOpPatterns(RewritePatternSet &patterns) {
+  patterns.add<FoldCollapseShapeIntoMapScatterPattern,
+               FoldExpandShapeIntoMapScatterPattern>(patterns.getContext());
+}
+
 /// Converts `tensor.extract_slice(tensor.expand_shape)` to
 /// `tensor.expand_shape(tensor.extract_slice)`.
 /// For this transformation to be possible, the slice must be fully contiguous
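
To make the linearize/delinearize rewiring in foldReshapeIntoMapScatter concrete, here is a small worked example (the static 4x8 shape is illustrative only, not taken from this change). Suppose a tensor.collapse_shape from tensor<4x8xf32> to tensor<32xf32> feeds a map_scatter. After the fold, the map_scatter reads the tensor<4x8xf32> source directly, and its index transformation is prefixed with:

  (i, j) -> linearize by sizes (4, 8) -> 8*i + j -> delinearize by size (32) -> (8*i + j)

so the original transformation still sees the same linear position it saw before the fold. For a tensor.expand_shape from tensor<32xf32> to tensor<4x8xf32> the roles flip: the single source index i is linearized by (32) and delinearized by (4, 8) into (i floordiv 8, i mod 8).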

compiler/src/iree/compiler/Codegen/Common/Transforms.h
Lines changed: 19 additions & 0 deletions

@@ -26,6 +26,22 @@ class ConfigTrackingListener : public RewriterBase::Listener {
   void notifyOperationReplaced(Operation *op, ValueRange replacement) override;
 };
 
+/// Fold a tensor::ExpandShapeOp into a consumer `mapScatterOp`, by linearizing
+/// and then delinearizing the source indices of the `mapScatterOp`s index
+/// transformation.
+IREE::LinalgExt::MapScatterOp
+foldExpandShapeIntoMapScatter(RewriterBase &rewriter,
+                              tensor::ExpandShapeOp expandShapeOp,
+                              IREE::LinalgExt::MapScatterOp mapScatterOp);
+
+/// Fold a tensor::CollapseShapeOp into a consumer `mapScatterOp`, by
+/// linearizing and then delinearizing the source indices of the
+/// `mapScatterOp`s index transformation.
+IREE::LinalgExt::MapScatterOp
+foldCollapseShapeIntoMapScatter(RewriterBase &rewriter,
+                                tensor::CollapseShapeOp collapseShapeOp,
+                                IREE::LinalgExt::MapScatterOp mapScatterOp);
+
 using IGEMMConfigFn =
     std::function<LogicalResult(linalg::GenericOp, IREE::LinalgExt::Im2colOp)>;
 using IGEMMControlFn = std::function<bool(Operation *)>;

@@ -94,6 +110,9 @@ void populateReplaceSlowMinMaxOpsPatterns(RewritePatternSet &patterns);
 
 void populateSwapExtractWithExpandPattern(RewritePatternSet &patterns);
 
+/// Populate patterns to fold relayout operations into map_scatter ops.
+void populateCombineRelayoutOpPatterns(RewritePatternSet &patterns);
+
 } // namespace mlir::iree_compiler
 
 #endif // IREE_COMPILER_CODEGEN_COMMON_TRANSFORMS_H_

compiler/src/iree/compiler/Codegen/Common/test/block_dynamic_dims.mlir
Lines changed: 28 additions & 0 deletions

@@ -399,3 +399,31 @@ func.func @check_bubble_up_patterns(%arg0 : tensor<4x32x?x32x?x32xf32>, %arg1 :
 // CHECK-SAME: %[[ARG0:.+]]: tensor<4x32x?x32x?x32xf32>
 //      CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]]
 //      CHECK: return %[[COLLAPSED]]
+
+// -----
+
+func.func @block_dims_with_map_scatter(%size: index) -> tensor<?xf32> {
+  %0 = util.assume.int %size<umin = 16, umax = 4080, udiv = 16> : index
+  %cst = arith.constant 0.0 : f32
+  %1 = tensor.empty(%0) : tensor<?xf32>
+  %2 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>],
+                       iterator_types = ["parallel"]}
+      outs(%1 : tensor<?xf32>) {
+  ^bb0(%out: f32):
+    linalg.yield %cst : f32
+  } -> tensor<?xf32>
+  %3 = iree_linalg_ext.map_scatter %2 into %1 {
+  ^bb0(%arg0: index):
+    %true = arith.constant true
+    iree_linalg_ext.yield %arg0, %true : index, i1
+  } : tensor<?xf32> into tensor<?xf32> -> tensor<?xf32>
+  return %3 : tensor<?xf32>
+}
+// Check that the reshapes are able to be folded into the map_scatter op
+//
+// CHECK-LABEL: func @block_dims_with_map_scatter(
+//   CHECK-DAG: %[[EMPTY:.+]] = tensor.empty{{.*}} tensor<?x16xf32>
+//       CHECK: %[[GENERIC:.+]] = linalg.generic
+//  CHECK-SAME:   outs(%[[EMPTY]] : tensor<?x16xf32>)
+//       CHECK: %[[MAP_SCATTER:.+]] = iree_linalg_ext.map_scatter
+//       CHECK: return %[[MAP_SCATTER]]
