Skip to content

Commit 77ba691

Browse files
authored
[mlir][linalg] Fix FoldReshapeWithGenericOpByCollapsing insertion point (#133476)
Fixes a dominance verifier error with `FoldReshapeWithGenericOpByCollapsing` by setting the insertion point after `producer`. The `tensor.collapse_shape` op has only a single operand (`producer`), so it is safe to insert after the producer. Signed-off-by: Ian Wood <[email protected]>
1 parent 864c76a commit 77ba691

File tree

2 files changed

+35
-0
lines changed

2 files changed

+35
-0
lines changed

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1897,6 +1897,9 @@ struct FoldReshapeWithGenericOpByCollapsing
18971897
"fusion blocked by control function");
18981898
}
18991899

1900+
// Set the insertion point after `producer` because there could be uses
1901+
// of `producer` between it and the `tensor.collapse_shape` op.
1902+
rewriter.setInsertionPointAfter(producer);
19001903
std::optional<CollapseResult> collapseResult =
19011904
collapseOpIterationDims(producer, collapsableIterationDims, rewriter);
19021905
if (!collapseResult) {

mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -798,3 +798,35 @@ func.func @fuse_by_collapsing_change_reshape_order_bubblecollapse(%arg0 : tensor
798798
// CONTROL-SAME: ins(%[[ARG0]],
799799
// CONTROL: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[GENERIC]]
800800
// CONTROL: return %[[COLLAPSE]]
801+
802+
// -----
803+
804+
// Check that new ops are inserted after `%0` because `%0` is also used by `tensor.dim`.
805+
#map0 = affine_map<(d0, d1) -> (d0, d1)>
806+
func.func @fuse_by_collapsing_correct_insertion(%arg0 : tensor<?x?xf32>,
807+
%sz0: index, %sz1: index) -> (tensor<?xf32>, index) {
808+
%c0 = arith.constant 0 : index
809+
%init = tensor.empty(%sz1, %sz0) : tensor<?x?xf32>
810+
%0 = linalg.generic {
811+
indexing_maps = [#map0, #map0],
812+
iterator_types = ["parallel", "parallel"]}
813+
ins(%arg0 : tensor<?x?xf32>)
814+
outs(%init : tensor<?x?xf32>) {
815+
^bb0(%b0 : f32, %b1 : f32):
816+
%out = arith.negf %b0 : f32
817+
linalg.yield %out : f32
818+
} -> tensor<?x?xf32>
819+
%dim = tensor.dim %0, %c0 : tensor<?x?xf32>
820+
%1 = tensor.collapse_shape %0 [[0, 1]] : tensor<?x?xf32> into tensor<?xf32>
821+
return %1, %dim : tensor<?xf32>, index
822+
}
823+
824+
// CHECK-LABEL: func @fuse_by_collapsing_correct_insertion
825+
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>
826+
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
827+
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]]
828+
// CHECK: %[[OUT:.+]] = linalg.generic
829+
// CHECK-SAME: ins(%[[COLLAPSE]] : tensor<?xf32>)
830+
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[OUT]]
831+
// CHECK: %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]]
832+
// CHECK: return %[[OUT]], %[[DIM]]

0 commit comments

Comments
 (0)