Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,9 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {

// Finally the index_map for the result must be invertible. For now just
// verify it is a permutation.
auto producerResult = cast<OpResult>(fusedOperand->get());
AffineMap producerResultIndexMap =
producer.getMatchingIndexingMap(producer.getDpsInitOperand(0));
producer.getIndexingMapMatchingResult(producerResult);
if (!producerResultIndexMap.isPermutation())
return false;

Expand Down
66 changes: 66 additions & 0 deletions mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1014,3 +1014,69 @@ module {
// CHECK-DAG: %[[T3:.+]] = arith.addf %[[T2]], %[[B1]]
// CHECK: linalg.yield %[[T3]] : f32
// CHECK: return %[[GENERIC]]

// -----

// In this test we expect the first two linalg.generic operations to be fused into one, but the third one (the matmul) to remain separate.
// The reason is that when the pattern is applied the first time, the fusion of the first two operations produces a fused operation with
// an additional result and an additional output indexing map that is not a permutation / not invertible.
// The fused op will also still produce the original result (and its output indexing map), which is preserved because the new indexing map
// is not invertible. Thus the fused op will have 2 results, but only the 2nd one will be used by the following matmul op as an input argument.
// When trying to apply the fusion pattern again, the matmul op won't be fused because the operand to fuse was not produced with an invertible indexing map.

// Identity map over four dimensions; used by the first two linalg.generic ops
// in the function below.
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// Folds all four dimensions into the first result and pins the remaining three
// results to the constant 0. This map is not a permutation. It is the output
// indexing map of the second linalg.generic.
#map1 = affine_map<(d0, d1, d2, d3) -> (d0 * 4 + d1 * 2 + d2 + d3, 0, 0, 0)>
// Indexing maps of the final seven-dimensional linalg.generic, in operand
// order: first input, second input, output. d6 appears only in the two input
// maps.
#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d6)>
#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
#map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d5)>
// Test input: a chain of three linalg.generic ops in which the middle op
// writes its result through #map1 and the last op consumes that result.
module {
// Takes two tensor<1x1x2x1xf32> arguments and returns a tensor<1x1x2x4xf32>.
func.func @fuse_only_as_long_as_result_map_is_permutation(%arg0: tensor<1x1x2x1xf32>, %arg1: tensor<1x1x2x1xf32>) -> tensor<1x1x2x4xf32> {
%c2 = arith.constant 2 : index
%c1 = arith.constant 1 : index
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
%0 = tensor.empty() : tensor<1x2x2x1xf32>
// First generic: no inputs, a single init operand, all-parallel iterators.
// Each element is either read from %arg1 or set to the constant %cst.
%1 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%0 : tensor<1x2x2x1xf32>) {
^bb0(%out: f32):
%6 = linalg.index 1 : index
%7 = linalg.index 2 : index
// Unsigned bounds tests of the iteration indices against 1 and 2, which are
// the sizes of dimensions 1 and 2 of %arg1.
%8 = arith.cmpi ult, %6, %c1 : index
%9 = arith.cmpi ult, %7, %c2 : index
%10 = arith.andi %8, %9 : i1
// Only extract from %arg1 when both indices pass the tests above; otherwise
// yield %cst (0.0).
%11 = scf.if %10 -> (f32) {
%extracted = tensor.extract %arg1[%c0, %6, %7, %c0] : tensor<1x1x2x1xf32>
scf.yield %extracted : f32
} else {
scf.yield %cst : f32
}
linalg.yield %11 : f32
} -> tensor<1x2x2x1xf32>
%2 = tensor.empty() : tensor<4x1x1x1xf32>
// Second generic: reads %1 through the identity map and writes into %2 through
// #map1. The body forwards the input element unchanged.
%3 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1 : tensor<1x2x2x1xf32>) outs(%2 : tensor<4x1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<4x1x1x1xf32>
%4 = tensor.empty() : tensor<1x1x2x4xf32>
// The 4-D init tensor is expanded to 6-D (trailing unit dimensions) so it can
// serve as the output of the seven-iterator generic below.
%expanded = tensor.expand_shape %4 [[0], [1], [2], [3, 4, 5]] output_shape [1, 1, 2, 4, 1, 1] : tensor<1x1x2x4xf32> into tensor<1x1x2x4x1x1xf32>
// Third generic: six parallel iterators plus one reduction iterator. It
// multiplies elements of %arg0 and %3 and accumulates into the output, taking
// %3 (the second generic's result) as its second input.
%5 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %3 : tensor<1x1x2x1xf32>, tensor<4x1x1x1xf32>) outs(%expanded : tensor<1x1x2x4x1x1xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%6 = arith.mulf %in, %in_0 : f32
%7 = arith.addf %6, %out : f32
linalg.yield %7 : f32
} -> tensor<1x1x2x4x1x1xf32>
// Collapse the 6-D result back to the 4-D function return type.
%collapsed = tensor.collapse_shape %5 [[0], [1], [2], [3, 4, 5]] : tensor<1x1x2x4x1x1xf32> into tensor<1x1x2x4xf32>
return %collapsed : tensor<1x1x2x4xf32>
}
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0 * 4 + d1 * 2 + d2 + d3, 0, 0, 0)>
// CHECK: func.func @fuse_only_as_long_as_result_map_is_permutation
// CHECK-SAME: (%[[ARG0:.*]]: tensor<1x1x2x1xf32>, %[[ARG1:.*]]: tensor<1x1x2x1xf32>) -> tensor<1x1x2x4xf32> {
// CHECK-DAG: %[[OUT0:.+]] = tensor.empty() : tensor<1x2x2x1xf32>
// CHECK-DAG: %[[OUT1:.+]] = tensor.empty() : tensor<4x1x1x1xf32>
// CHECK: %[[FUSED:.+]]:2 = linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
// CHECK-SAME: outs(%[[OUT0]], %[[OUT1]] : tensor<1x2x2x1xf32>, tensor<4x1x1x1xf32>)
// CHECK-NOT: linalg.generic
// CHECK: tensor.expand_shape
// CHECK: linalg.generic {{.*}}, iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "reduction"]}
// CHECK-SAME: ins(%[[ARG0]], %[[FUSED]]#1 : tensor<1x1x2x1xf32>, tensor<4x1x1x1xf32>)
Loading