Skip to content

Commit b9a06f6

Browse files
authored
[LinalgExt] Fix scatter unique_indices when dropping unit dims (#22362)
Correctly preserve `unique_indices` when creating a new scatter op in `DropScatterUnitDims`. There doesn't seem to be any other incorrect usages of `ScatterOp::create`. Fixes #22361 Signed-off-by: Ian Wood <[email protected]>
1 parent 7e81e54 commit b9a06f6

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -815,7 +815,7 @@ struct DropScatterUnitDims final : public OpRewritePattern<ScatterOp> {
815815
auto newScatter = ScatterOp::create(
816816
rewriter, scatterOp.getLoc(), TypeRange{original.getType()},
817817
ValueRange{updates, indices}, ValueRange{original},
818-
scatterOp.getDimensionMap());
818+
scatterOp.getDimensionMap(), scatterOp.getUniqueIndices());
819819
rewriter.inlineRegionBefore(scatterOp.getRegion(), newScatter.getRegion(),
820820
newScatter.getRegion().begin());
821821
rewriter.replaceOp(scatterOp,

compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/fold_unit_dims.mlir

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ util.func public @scatter_batch_and_slice_dims(%slice: tensor<4x1x1x4xf16>, %ind
7878
// RESHAPE: %[[UPDATE_COLLAPSE:.+]] = tensor.collapse_shape %[[UPDATE]]
7979
// RESHAPE-SAME: tensor<4x1x1x4xf16> into tensor<4x4xf16>
8080
// RESHAPE: iree_linalg_ext.scatter
81+
// RESHAPE-SAME: unique_indices(true)
8182
// RESHAPE-SAME: ins(%[[UPDATE_COLLAPSE]], %[[INDICES_COLLAPSE]]
8283
// RESHAPE-SAME: outs(%[[ORIGINAL_COLLAPSE]]
8384

@@ -94,14 +95,15 @@ util.func public @scatter_batch_and_slice_dims(%slice: tensor<4x1x1x4xf16>, %ind
9495
// SLICE: %[[UPDATE_SLICE2:.+]] = tensor.extract_slice %[[UPDATE_SLICE]]
9596
// SLICE-SAME: tensor<4x1x4xf16> to tensor<4x4xf16>
9697
// SLICE: iree_linalg_ext.scatter
98+
// SLICE-SAME: unique_indices(true)
9799
// SLICE-SAME: ins(%[[UPDATE_SLICE2]], %[[INDICES_SLICE]]
98100
// SLICE-SAME: outs(%[[ORIGINAL_SLICE]]
99101

100102
// -----
101103

102104
util.func public @scatter_no_change_output(%slice: tensor<1x2xf16>, %indices: tensor<1x2x2xi64>) -> tensor<2x2xf16> {
103105
%empty = tensor.empty() : tensor<2x2xf16>
104-
%0 = iree_linalg_ext.scatter dimension_map = [0, 1] unique_indices(true)
106+
%0 = iree_linalg_ext.scatter dimension_map = [0, 1] unique_indices(false)
105107
ins(%slice, %indices: tensor<1x2xf16>, tensor<1x2x2xi64>)
106108
outs(%empty: tensor<2x2xf16>){
107109
^bb0(%in : f16, %out : f16):
@@ -118,6 +120,7 @@ util.func public @scatter_no_change_output(%slice: tensor<1x2xf16>, %indices: te
118120
// RESHAPE: %[[UPDATE_COLLAPSE:.+]] = tensor.collapse_shape %[[UPDATE]]
119121
// RESHAPE-SAME: tensor<1x2xf16> into tensor<2xf16>
120122
// RESHAPE: iree_linalg_ext.scatter
123+
// RESHAPE-SAME: unique_indices(false)
121124
// RESHAPE-SAME: ins(%[[UPDATE_COLLAPSE]], %[[INDICES_COLLAPSE]]
122125
// RESHAPE-SAME: outs(%[[ORIGINAL]]
123126

@@ -130,6 +133,7 @@ util.func public @scatter_no_change_output(%slice: tensor<1x2xf16>, %indices: te
130133
// SLICE: %[[UPDATE_SLICE:.+]] = tensor.extract_slice %[[UPDATE]]
131134
// SLICE-SAME: tensor<1x2xf16> to tensor<2xf16>
132135
// SLICE: iree_linalg_ext.scatter
136+
// SLICE-SAME: unique_indices(false)
133137
// SLICE-SAME: ins(%[[UPDATE_SLICE]], %[[INDICES_SLICE]]
134138
// SLICE-SAME: outs(%[[ORIGINAL]]
135139

0 commit comments

Comments
 (0)