diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index d6b7ab0a980eb..bf70597d5ddfe 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1897,6 +1897,9 @@ struct FoldReshapeWithGenericOpByCollapsing
                                          "fusion blocked by control function");
     }
 
+    // Set the insertion point after `producer` because there could be uses
+    // of `producer` between it and the `tensor.collapse_shape` op.
+    rewriter.setInsertionPointAfter(producer);
     std::optional<CollapseResult> collapseResult =
         collapseOpIterationDims(producer, collapsableIterationDims, rewriter);
     if (!collapseResult) {
diff --git a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
index 21178fd7e783f..dba53b4192cd5 100644
--- a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
+++ b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
@@ -798,3 +798,35 @@ func.func @fuse_by_collapsing_change_reshape_order_bubblecollapse(%arg0 : tensor
 // CONTROL-SAME: ins(%[[ARG0]],
 // CONTROL: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[GENERIC]]
 // CONTROL: return %[[COLLAPSE]]
+
+// -----
+
+// Check that new ops are inserted at `%0` because `%0` is also used by `tensor.dim`.
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @fuse_by_collapsing_correct_insertion(%arg0 : tensor<?x?xf32>,
+    %sz0: index, %sz1: index) -> (tensor<?xf32>, index) {
+  %c0 = arith.constant 0 : index
+  %init = tensor.empty(%sz1, %sz0) : tensor<?x?xf32>
+  %0 = linalg.generic {
+      indexing_maps = [#map0, #map0],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%arg0 : tensor<?x?xf32>)
+      outs(%init : tensor<?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32):
+      %out = arith.negf %b0 : f32
+      linalg.yield %out : f32
+  } -> tensor<?x?xf32>
+  %dim = tensor.dim %0, %c0 : tensor<?x?xf32>
+  %1 = tensor.collapse_shape %0 [[0, 1]] : tensor<?x?xf32> into tensor<?xf32>
+  return %1, %dim : tensor<?xf32>, index
+}
+
+// CHECK-LABEL: func @fuse_by_collapsing_correct_insertion
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]]
+// CHECK: %[[OUT:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[COLLAPSE]] : tensor<?xf32>)
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[OUT]]
+// CHECK: %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]]
+// CHECK: return %[[OUT]], %[[DIM]]
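
For context, here is a minimal sketch (not part of the patch) of how these fold-by-collapsing patterns are typically driven from a pass. `runFoldByCollapsing` and its accept-everything control predicate are hypothetical; `linalg::populateFoldReshapeOpsByCollapsingPatterns`, `linalg::ControlFusionFn`, and the greedy driver are existing MLIR entry points, though exact driver names can vary between MLIR revisions. With the insertion-point change above, the collapsed `linalg.generic` and the rebuilt `tensor.expand_shape` are created right after the producer, so other users of the producer, such as the `tensor.dim` in the new test, still see a dominating definition.

```cpp
// Sketch: wire the reshape-by-collapsing fusion patterns into a rewrite run.
// Hypothetical helper; only the populate/apply calls are real MLIR APIs.
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

static LogicalResult runFoldByCollapsing(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  // Accept every candidate; a real pass would normally inspect the fused
  // operand (e.g. only fuse when the producer has a single non-dim use).
  linalg::ControlFusionFn controlFn = [](OpOperand *fusedOperand) {
    return true;
  };
  linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, controlFn);
  // FoldReshapeWithGenericOpByCollapsing (patched above) now materializes the
  // collapsed op right after the producer instead of at the reshape.
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}
```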