diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index f97ed3d6d5111..7bffac5425f8f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -172,8 +172,9 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
 
   // Finally the index_map for the result must be invertible. For now just
   // verify it is a permutation.
+  auto producerResult = cast<OpResult>(fusedOperand->get());
   AffineMap producerResultIndexMap =
-      producer.getMatchingIndexingMap(producer.getDpsInitOperand(0));
+      producer.getIndexingMapMatchingResult(producerResult);
   if (!producerResultIndexMap.isPermutation())
     return false;
 
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index 66fc55fadf8fa..bc55c12c02f29 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -1014,3 +1014,69 @@ module {
 // CHECK-DAG: %[[T3:.+]] = arith.addf %[[T2]], %[[B1]]
 // CHECK: linalg.yield %[[T3]] : f32
 // CHECK: return %[[GENERIC]]
+
+// -----
+
+// In this test we expect the first two linalg.generic operations to be fused into one, while the third one (the matmul) remains separate.
+// When the pattern is applied the first time, fusing the first two operations produces a fused operation with an additional result
+// and an additional output indexing map that is not a permutation / not invertible.
+// The fused op still also produces the original result (and keeps its output indexing map), which is preserved precisely because the
+// new indexing map is not invertible. Thus the fused op has 2 results, but only the 2nd one is used by the following matmul op as an input argument.
+// When the fusion pattern is applied again, the matmul op is not fused because its operand was not produced with an invertible indexing map.
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0 * 4 + d1 * 2 + d2 + d3, 0, 0, 0)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d6)>
+#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
+#map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d5)>
+module {
+  func.func @fuse_only_as_long_as_result_map_is_permutation(%arg0: tensor<1x1x2x1xf32>, %arg1: tensor<1x1x2x1xf32>) -> tensor<1x1x2x4xf32> {
+    %c2 = arith.constant 2 : index
+    %c1 = arith.constant 1 : index
+    %cst = arith.constant 0.000000e+00 : f32
+    %c0 = arith.constant 0 : index
+    %0 = tensor.empty() : tensor<1x2x2x1xf32>
+    %1 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%0 : tensor<1x2x2x1xf32>) {
+    ^bb0(%out: f32):
+      %6 = linalg.index 1 : index
+      %7 = linalg.index 2 : index
+      %8 = arith.cmpi ult, %6, %c1 : index
+      %9 = arith.cmpi ult, %7, %c2 : index
+      %10 = arith.andi %8, %9 : i1
+      %11 = scf.if %10 -> (f32) {
+        %extracted = tensor.extract %arg1[%c0, %6, %7, %c0] : tensor<1x1x2x1xf32>
+        scf.yield %extracted : f32
+      } else {
+        scf.yield %cst : f32
+      }
+      linalg.yield %11 : f32
+    } -> tensor<1x2x2x1xf32>
+    %2 = tensor.empty() : tensor<4x1x1x1xf32>
+    %3 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1 : tensor<1x2x2x1xf32>) outs(%2 : tensor<4x1x1x1xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      linalg.yield %in : f32
+    } -> tensor<4x1x1x1xf32>
+    %4 = tensor.empty() : tensor<1x1x2x4xf32>
+    %expanded = tensor.expand_shape %4 [[0], [1], [2], [3, 4, 5]] output_shape [1, 1, 2, 4, 1, 1] : tensor<1x1x2x4xf32> into tensor<1x1x2x4x1x1xf32>
+    %5 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %3 : tensor<1x1x2x1xf32>, tensor<4x1x1x1xf32>) outs(%expanded : tensor<1x1x2x4x1x1xf32>) {
+    ^bb0(%in: f32, %in_0: f32, %out: f32):
+      %6 = arith.mulf %in, %in_0 : f32
+      %7 = arith.addf %6, %out : f32
+      linalg.yield %7 : f32
+    } -> tensor<1x1x2x4x1x1xf32>
+    %collapsed = tensor.collapse_shape %5 [[0], [1], [2], [3, 4, 5]] : tensor<1x1x2x4x1x1xf32> into tensor<1x1x2x4xf32>
+    return %collapsed : tensor<1x1x2x4xf32>
+  }
+}
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0 * 4 + d1 * 2 + d2 + d3, 0, 0, 0)>
+// CHECK: func.func @fuse_only_as_long_as_result_map_is_permutation
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<1x1x2x1xf32>, %[[ARG1:.*]]: tensor<1x1x2x1xf32>) -> tensor<1x1x2x4xf32> {
+// CHECK-DAG: %[[OUT0:.+]] = tensor.empty() : tensor<1x2x2x1xf32>
+// CHECK-DAG: %[[OUT1:.+]] = tensor.empty() : tensor<4x1x1x1xf32>
+// CHECK: %[[FUSED:.+]]:2 = linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+// CHECK-SAME: outs(%[[OUT0]], %[[OUT1]] : tensor<1x2x2x1xf32>, tensor<4x1x1x1xf32>)
+// CHECK-NOT: linalg.generic
+// CHECK: tensor.expand_shape
+// CHECK: linalg.generic {{.*}}, iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "reduction"]}
+// CHECK-SAME: ins(%[[ARG0]], %[[FUSED]]#1 : tensor<1x1x2x1xf32>, tensor<4x1x1x1xf32>)
\ No newline at end of file