Skip to content

Commit 340ffbb

Browse files
authored
[LinalgExt] Drop the unit dims on scatter ops 2/3 (iree-org#19450)
This change adds patterns to drop the unit dims of a `iree_linalg_ext.scatter`'s `%updates` tensor. It only drops the leading unit dimensions from the portion of `updates` that represents the indexed dimensions. See the main issue iree-org#19091 --------- Signed-off-by: Ian Wood <[email protected]>
1 parent 0820f10 commit 340ffbb

File tree

4 files changed

+135
-0
lines changed

4 files changed

+135
-0
lines changed

compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,73 @@ struct FoldAttentionWithProducerReshapeByExpansion final
486486
linalg::ControlFusionFn controlFoldingReshapes;
487487
};
488488

489+
/// Remove the unit dims from `iree_linalg_ext.scatter` 's `update` operand.
490+
/// The dims in `update` between the batch dims and the continuous slice
491+
/// represent the indexed dimensions. Remove the leading unit dims from the
492+
/// indexed dims.
493+
struct FoldScatterNonIterationUnitDims final
494+
: public OpRewritePattern<ScatterOp> {
495+
FoldScatterNonIterationUnitDims(MLIRContext *context,
496+
linalg::ControlDropUnitDims options,
497+
PatternBenefit benefit = 1)
498+
: OpRewritePattern<ScatterOp>(context, benefit),
499+
options(std::move(options)) {}
500+
501+
LogicalResult matchAndRewrite(ScatterOp scatterOp,
502+
PatternRewriter &rewriter) const override {
503+
if (options.rankReductionStrategy !=
504+
linalg::ControlDropUnitDims::RankReductionStrategy::
505+
ReassociativeReshape) {
506+
return rewriter.notifyMatchFailure(
507+
scatterOp, "Only reassociative reshape strategy supported");
508+
}
509+
llvm::SmallVector<unsigned> canDrop = options.controlFn(scatterOp);
510+
const ArrayRef<int64_t> updateShape = scatterOp.getUpdateType().getShape();
511+
512+
// Find the number of leading unit dimensions
513+
int64_t rankOfContiguousSlice =
514+
scatterOp.getOriginalType().getRank() - scatterOp.getIndexDepth();
515+
ArrayRef<int64_t> indexedDims =
516+
scatterOp.getUpdateSliceShape().drop_back(rankOfContiguousSlice);
517+
int64_t numDimsToDrop =
518+
llvm::find_if(indexedDims, [](int64_t val) { return val != 1; }) -
519+
scatterOp.getUpdateSliceShape().begin() - 1;
520+
521+
int64_t batchRank = scatterOp.getBatchRank();
522+
llvm::erase_if(canDrop, [&](unsigned dimPos) {
523+
return dimPos < batchRank || dimPos > batchRank + numDimsToDrop;
524+
});
525+
if (canDrop.empty()) {
526+
return failure();
527+
}
528+
529+
SmallVector<int64_t> droppedUpdateShape;
530+
droppedUpdateShape.reserve(updateShape.size() - canDrop.size());
531+
for (auto [idx, dimLen] : llvm::enumerate(updateShape)) {
532+
if (!llvm::is_contained(canDrop, idx)) {
533+
droppedUpdateShape.push_back(dimLen);
534+
}
535+
}
536+
537+
auto reassoc =
538+
getReassociationIndicesForCollapse(updateShape, droppedUpdateShape);
539+
assert(reassoc.has_value() && "expected reassociation to be valid");
540+
auto collapseOp = rewriter.create<tensor::CollapseShapeOp>(
541+
scatterOp.getLoc(),
542+
RankedTensorType::get(droppedUpdateShape,
543+
scatterOp.getUpdateType().getElementType()),
544+
scatterOp.getUpdates(), reassoc.value());
545+
546+
rewriter.modifyOpInPlace(scatterOp, [&]() {
547+
scatterOp.setOperand(ScatterOp::kUpdatesOpNum, collapseOp.getResult());
548+
});
549+
return success();
550+
}
551+
552+
private:
553+
linalg::ControlDropUnitDims options;
554+
};
555+
489556
} // namespace
490557

491558
/// Return the `reassociation` indices to use to collapse the operand when the
@@ -708,4 +775,14 @@ void populateFoldReshapeOpsByExpansionPatterns(
708775
patterns.getContext(), controlFoldingReshapes);
709776
}
710777

778+
SmallVector<unsigned> defaultControlDropUnitDims(Operation *op) {
779+
auto fusionOp = cast<LinalgFusionOpInterface>(op);
780+
return llvm::to_vector(llvm::seq<unsigned>(0, fusionOp.getNumLoops()));
781+
}
782+
783+
void populateFoldUnitExtentDimsPatterns(
784+
RewritePatternSet &patterns, const linalg::ControlDropUnitDims &options) {
785+
patterns.add<FoldScatterNonIterationUnitDims>(patterns.getContext(), options);
786+
}
787+
711788
} // namespace mlir::iree_compiler::IREE::LinalgExt

compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,13 @@ void populateBubbleTransposeFromLinalgExtOps(
2525
RewritePatternSet &patterns,
2626
const linalg::ControlFusionFn &controlFusionFn);
2727

28+
/// Default function to drop unit dims for for linalgext ops.
29+
SmallVector<unsigned> defaultControlDropUnitDims(Operation *op);
30+
31+
/// Drop unit extent dims from linalg ext ops
32+
void populateFoldUnitExtentDimsPatterns(
33+
RewritePatternSet &patterns, const linalg::ControlDropUnitDims &options);
34+
2835
/// Helper struct to hold the results of collapsing an operation.
2936
struct CollapseResult {
3037
SmallVector<Value> results;

compiler/src/iree/compiler/DispatchCreation/FoldUnitExtentDims.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
//===----------------------------------------------------------------------===//
1313

1414
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
15+
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
16+
#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
1517
#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
1618
#include "iree/compiler/Dialect/Util/Analysis/Explorer.h"
1719
#include "iree/compiler/DispatchCreation/Passes.h"
@@ -151,9 +153,14 @@ void FoldUnitExtentDimsPass::runOnOperation() {
151153
if (!IREE::Flow::isNonNullAndOutsideDispatch(op)) {
152154
return SmallVector<unsigned>{};
153155
}
156+
if (isa<IREE::LinalgExt::LinalgExtOp>(op)) {
157+
return IREE::LinalgExt::defaultControlDropUnitDims(op);
158+
}
154159
return defaultFn(op);
155160
};
156161
linalg::populateFoldUnitExtentDimsPatterns(foldUnitDimsPatterns, options);
162+
IREE::LinalgExt::populateFoldUnitExtentDimsPatterns(foldUnitDimsPatterns,
163+
options);
157164
linalg::populateMoveInitOperandsToInputPattern(foldUnitDimsPatterns);
158165
if (failed(
159166
applyPatternsGreedily(moduleOp, std::move(foldUnitDimsPatterns)))) {

compiler/src/iree/compiler/DispatchCreation/test/fold_unit_dims.mlir

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,3 +106,47 @@ module @fold_stream_parameter {
106106
// CHECK: util.global private mutable @[[GLOBAL:.+]] = #stream.parameter.named<"module"::"global"> : tensor<10xf32>
107107
// CHECK: util.func public @fold_stream_parameter
108108
// CHECK: %[[LOAD:.+]] = util.global.load @[[GLOBAL]] : tensor<10xf32>
109+
110+
// -----
111+
112+
util.func public @scatter0(%arg0: tensor<?x1x2x16x4x128xf16>, %arg1: tensor<?x1xi32>, %arg2: tensor<?x2x16x4x128xf16>) -> tensor<?x2x16x4x128xf16> {
113+
%0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%arg0, %arg1 : tensor<?x1x2x16x4x128xf16>, tensor<?x1xi32>) outs(%arg2 : tensor<?x2x16x4x128xf16>) {
114+
^bb0(%arg3: f16, %arg4: f16):
115+
iree_linalg_ext.yield %arg3 : f16
116+
} -> tensor<?x2x16x4x128xf16>
117+
util.return %0 : tensor<?x2x16x4x128xf16>
118+
}
119+
// CHECK-LABEL: func public @scatter0
120+
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape
121+
// CHECK-SAME: to tensor<?x2x16x4x128xf16>
122+
// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter
123+
// CHECK-SAME: ins(%[[COLLAPSE]]
124+
125+
// -----
126+
127+
util.func public @scatter1(%arg0: tensor<?x1x1x16x4x128xf16>, %arg1: tensor<?x2xi32>, %arg2: tensor<?x2x16x4x128xf16>) -> tensor<?x2x16x4x128xf16> {
128+
%0 = iree_linalg_ext.scatter dimension_map = [0, 1] unique_indices(true) ins(%arg0, %arg1 : tensor<?x1x1x16x4x128xf16>, tensor<?x2xi32>) outs(%arg2 : tensor<?x2x16x4x128xf16>) {
129+
^bb0(%arg3: f16, %arg4: f16):
130+
iree_linalg_ext.yield %arg3 : f16
131+
} -> tensor<?x2x16x4x128xf16>
132+
util.return %0 : tensor<?x2x16x4x128xf16>
133+
}
134+
// CHECK-LABEL: func public @scatter1
135+
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape
136+
// CHECK-SAME: to tensor<?x16x4x128xf16>
137+
// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter
138+
// CHECK-SAME: ins(%[[COLLAPSE]]
139+
140+
// -----
141+
142+
// TODO: remove other unit dims.
143+
util.func public @scatter_noop(%arg0: tensor<1x?x1x1x4x128xf16>, %arg1: tensor<1x?x1x2xi32>, %arg2: tensor<?x2x1x4x128xf16>) -> tensor<?x2x1x4x128xf16> {
144+
%0 = iree_linalg_ext.scatter dimension_map = [0, 1] unique_indices(true) ins(%arg0, %arg1 : tensor<1x?x1x1x4x128xf16>, tensor<1x?x1x2xi32>) outs(%arg2 : tensor<?x2x1x4x128xf16>) {
145+
^bb0(%arg3: f16, %arg4: f16):
146+
iree_linalg_ext.yield %arg3 : f16
147+
} -> tensor<?x2x1x4x128xf16>
148+
util.return %0 : tensor<?x2x1x4x128xf16>
149+
}
150+
// CHECK-LABEL: func public @scatter_noop
151+
// CHECK-NOT: tensor.collapse_shape
152+
// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter

0 commit comments

Comments
 (0)