Commit 643382b
[LinalgExt] Implement unit dim folding pattern for map_scatter (#21563)

Implements unit dim folding on the source of `iree_linalg_ext.map_scatter`. This is needed by the LLVMGPUVectorDistribute pipeline, where unit dim slices would otherwise create extra copies. The folding applies only to the source of the map_scatter because we currently only see the op during codegen, where it is the last op in the dispatch. We generally shouldn't see unit dims on the output of the map_scatter, since unit dims at this stage come from tiling, which affects only the input.

---------

Signed-off-by: Max Dawkins <[email protected]>
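In effect (see the new tests below): a tensor<1x2x1x1xf16> source collapses to tensor<2xf16>, and the folded op's index transformation re-inserts a constant zero for each dropped unit dim before the original index mapping is applied.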
1 parent 510b024 commit 643382b

2 files changed: +111 −2

compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp

Lines changed: 51 additions & 2 deletions
@@ -575,6 +575,55 @@ Value rankExpandValue(RewriterBase &rewriter, Location loc, Value destVal,
   }
 }
 
+struct DropMapScatterUnitDims final : public OpRewritePattern<MapScatterOp> {
+  using OpRewritePattern<MapScatterOp>::OpRewritePattern;
+  DropMapScatterUnitDims(MLIRContext *context,
+                         linalg::ControlDropUnitDims options,
+                         PatternBenefit benefit = 1)
+      : OpRewritePattern<MapScatterOp>(context, benefit),
+        options(std::move(options)) {}
+
+  LogicalResult matchAndRewrite(MapScatterOp mapScatterOp,
+                                PatternRewriter &rewriter) const override {
+    auto inputType = dyn_cast<RankedTensorType>(mapScatterOp.getInputType());
+    if (!inputType) {
+      return failure();
+    }
+    Location loc = mapScatterOp.getLoc();
+    FailureOr<Value> newInput = rankReduceOperand(
+        rewriter, loc, /*startDim=*/0, /*numDims=*/mapScatterOp.getInputRank(),
+        mapScatterOp.getInput(), mapScatterOp.getInputType(), options);
+    if (failed(newInput)) {
+      return failure();
+    }
+
+    auto newInputType = cast<ShapedType>(newInput->getType());
+    auto unitFoldingBuilder = [&](ArrayRef<BlockArgument> nonUnitIndices) {
+      Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+      int nonUnitArgIdx = 0;
+      return llvm::map_to_vector(
+          llvm::seq<int64_t>(inputType.getRank()), [&](int64_t dim) -> Value {
+            return inputType.getDimSize(dim) == 1
+                       ? zero
+                       : cast<Value>(nonUnitIndices[nonUnitArgIdx++]);
+          });
+    };
+    // The map_scatter op is generally only used in Codegen, where it is the
+    // last op in the dispatch, so for now, we don't bother collapsing the
+    // result shape and inserting an expansion after the op.
+    rewriter.modifyOpInPlace(mapScatterOp, [&]() {
+      mapScatterOp.getInputMutable().assign(newInput.value());
+      mapScatterOp.insertTransformationAtStart(
+          rewriter, unitFoldingBuilder,
+          /*numSourceIndices=*/newInputType.getRank());
+    });
+    return success();
+  }
+
+private:
+  linalg::ControlDropUnitDims options;
+};
+
 struct DropGatherUnitDims final : public OpRewritePattern<GatherOp> {
   DropGatherUnitDims(MLIRContext *context, linalg::ControlDropUnitDims options,
                      PatternBenefit benefit = 1)
@@ -1059,8 +1108,8 @@ struct DropAttentionUnitDims final
 void populateFoldUnitExtentDimsPatterns(
     RewritePatternSet &patterns, const linalg::ControlDropUnitDims &options) {
   patterns.add<DropScatterUnitIndexDepth>(patterns.getContext());
-  patterns.add<DropGatherUnitDims, DropScatterUnitDims, DropAttentionUnitDims>(
-      patterns.getContext(), options);
+  patterns.add<DropGatherUnitDims, DropScatterUnitDims, DropAttentionUnitDims,
+               DropMapScatterUnitDims>(patterns.getContext(), options);
 }
 
 } // namespace mlir::iree_compiler::IREE::LinalgExt
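
For reference, a minimal sketch of how these patterns are typically driven. Only populateFoldUnitExtentDimsPatterns comes from this commit; the surrounding helper, the header paths, and the greedy driver call are assumptions about standard MLIR pass plumbing:

    // Hypothetical driver, not part of this commit. Header paths are assumed.
    #include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
    #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
    #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

    mlir::LogicalResult foldUnitDims(mlir::Operation *root) {
      mlir::RewritePatternSet patterns(root->getContext());
      // Default control folds all unit dims; a control callback can restrict
      // which dims are dropped.
      mlir::linalg::ControlDropUnitDims options;
      // Registers DropMapScatterUnitDims alongside the gather/scatter/attention
      // unit-dim patterns (see the registration hunk above).
      mlir::iree_compiler::IREE::LinalgExt::populateFoldUnitExtentDimsPatterns(
          patterns, options);
      // Run the standard greedy rewriter to a fixed point.
      return mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
    }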

compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/fold_unit_dims.mlir

Lines changed: 60 additions & 0 deletions
@@ -132,3 +132,63 @@ util.func public @scatter_no_change_output(%slice: tensor<1x2xf16>, %indices: te
 // SLICE: iree_linalg_ext.scatter
 // SLICE-SAME: ins(%[[UPDATE_SLICE]], %[[INDICES_SLICE]]
 // SLICE-SAME: outs(%[[ORIGINAL]]
+
+// -----
+
+util.func public @map_scatter(%input: tensor<1x2x1x1xf16>) -> tensor<2x2x2x2xf16> {
+  %empty = tensor.empty() : tensor<2x2x2x2xf16>
+  %0 = iree_linalg_ext.map_scatter %input into %empty {
+  ^bb0(%idx0: index, %idx1: index, %idx2: index, %idx3: index):
+    %mask = arith.constant true
+    iree_linalg_ext.yield %idx0, %idx1, %idx2, %idx3, %mask : index, index, index, index, i1
+  } : tensor<1x2x1x1xf16> into tensor<2x2x2x2xf16> -> tensor<2x2x2x2xf16>
+  util.return %0 : tensor<2x2x2x2xf16>
+}
+// RESHAPE-LABEL: util.func public @map_scatter
+// RESHAPE-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<1x2x1x1xf16>
+// RESHAPE-DAG: %[[DEST:.+]] = tensor.empty() : tensor<2x2x2x2xf16>
+// RESHAPE-DAG: %[[C0:.+]] = arith.constant 0 : index
+// RESHAPE-DAG: %[[INPUT_COLLAPSE:.+]] = tensor.collapse_shape %[[INPUT]]
+// RESHAPE-SAME: tensor<1x2x1x1xf16> into tensor<2xf16>
+// RESHAPE: iree_linalg_ext.map_scatter %[[INPUT_COLLAPSE]] into %[[DEST]]
+// RESHAPE: ^bb0(%[[IDX:.+]]: index):
+// RESHAPE: iree_linalg_ext.yield %[[C0]], %[[IDX]], %[[C0]], %[[C0]]
+
+// SLICE-LABEL: util.func public @map_scatter
+// SLICE-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<1x2x1x1xf16>
+// SLICE-DAG: %[[DEST:.+]] = tensor.empty() : tensor<2x2x2x2xf16>
+// SLICE-DAG: %[[INPUT_SLICE:.+]] = tensor.extract_slice %[[INPUT]]
+// SLICE-SAME: tensor<1x2x1x1xf16> to tensor<2xf16>
+// SLICE: iree_linalg_ext.map_scatter %[[INPUT_SLICE]] into %[[DEST]]
+// SLICE: ^bb0(%[[IDX:.+]]: index):
+// SLICE: %[[C0:.+]] = arith.constant 0 : index
+// SLICE: iree_linalg_ext.yield %[[C0]], %[[IDX]], %[[C0]], %[[C0]]
+
+// -----
+
+util.func public @map_scatter_all_unit(%input: tensor<1x1xf16>) -> tensor<2x2xf16> {
+  %empty = tensor.empty() : tensor<2x2xf16>
+  %0 = iree_linalg_ext.map_scatter %input into %empty {
+  ^bb0(%idx0: index, %idx1: index):
+    %mask = arith.constant true
+    iree_linalg_ext.yield %idx0, %idx1, %mask : index, index, i1
+  } : tensor<1x1xf16> into tensor<2x2xf16> -> tensor<2x2xf16>
+  util.return %0 : tensor<2x2xf16>
+}
+// RESHAPE-LABEL: util.func public @map_scatter_all_unit
+// RESHAPE-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<1x1xf16>
+// RESHAPE-DAG: %[[DEST:.+]] = tensor.empty() : tensor<2x2xf16>
+// RESHAPE-DAG: %[[C0:.+]] = arith.constant 0 : index
+// RESHAPE-DAG: %[[INPUT_COLLAPSE:.+]] = tensor.collapse_shape %[[INPUT]]
+// RESHAPE-SAME: tensor<1x1xf16> into tensor<1xf16>
+// RESHAPE: iree_linalg_ext.map_scatter %[[INPUT_COLLAPSE]] into %[[DEST]]
+// RESHAPE: iree_linalg_ext.yield %[[C0]], %[[C0]]
+
+// SLICE-LABEL: util.func public @map_scatter_all_unit
+// SLICE-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<1x1xf16>
+// SLICE-DAG: %[[DEST:.+]] = tensor.empty() : tensor<2x2xf16>
+// SLICE-DAG: %[[INPUT_SLICE:.+]] = tensor.extract_slice %[[INPUT]]
+// SLICE-SAME: tensor<1x1xf16> to tensor<1xf16>
+// SLICE: iree_linalg_ext.map_scatter %[[INPUT_SLICE]] into %[[DEST]]
+// SLICE: %[[C0:.+]] = arith.constant 0 : index
+// SLICE: iree_linalg_ext.yield %[[C0]], %[[C0]]
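
Design note: in the all-unit case the source is collapsed to a rank-1 tensor<1xf16> rather than to a rank-0 tensor, and every yielded index becomes the constant zero, since all of the original source dims were unit.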
