Skip to content

Commit f8d832f

Browse files
fabrizio-indirli and mahesh-attarde
authored and committed
[mlir][linalg] Fix to Elementwise Fusion when preserving results (llvm#149843)
In the linalg ElementwiseOpFusion transform, a pre-requisite for the fusion between a producer and consumer op is that the producer's output indexing map associated to the result to be fused must be invertible (e.g. a simple permutation). Before this patch, only the first output indexing map was being checked; this bug produced issues when the operand to fuse was not the 1st result of the producer op. For example, this situation arises when the producer op has multiple results because it's the result of previous fusions where the original result had been preserved: in these cases, the pass ought to check the indexing map of the result being fused, which is not necessarily the 1st one. Signed-off-by: Fabrizio Indirli <[email protected]>
1 parent 9729a8c commit f8d832f

File tree

2 files changed

+68
-1
lines changed

2 files changed

+68
-1
lines changed

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,9 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
172172

173173
// Finally the index_map for the result must be invertible. For now just
174174
// verify it is a permutation.
175+
auto producerResult = cast<OpResult>(fusedOperand->get());
175176
AffineMap producerResultIndexMap =
176-
producer.getMatchingIndexingMap(producer.getDpsInitOperand(0));
177+
producer.getIndexingMapMatchingResult(producerResult);
177178
if (!producerResultIndexMap.isPermutation())
178179
return false;
179180

mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1014,3 +1014,69 @@ module {
10141014
// CHECK-DAG: %[[T3:.+]] = arith.addf %[[T2]], %[[B1]]
10151015
// CHECK: linalg.yield %[[T3]] : f32
10161016
// CHECK: return %[[GENERIC]]
1017+
1018+
// -----
1019+
1020+
// In this test we expect the first two linalg.generic operations to be fused into one, but the third one (the matmul) to remain separate.
1021+
// The reason is that when the pattern is applied the 1st time, the fusion of the first two operations produces a fused operation with
1022+
// an additional result and an additional output indexing map that is not a permutation / not invertible.
1023+
// The fused op will still produce also the original result (and its output indexing map), which is preserved because the new indexing map
1024+
// is not invertible. Thus the fused op will have 2 results, but only the 2nd one will be used by the following matmul op as an input argument.
1025+
// When trying to apply the fusion pattern again, the matmul op won't be fused because the operand to fuse was not produced with an invertible indexing map.
1026+
1027+
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
1028+
#map1 = affine_map<(d0, d1, d2, d3) -> (d0 * 4 + d1 * 2 + d2 + d3, 0, 0, 0)>
1029+
#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d6)>
1030+
#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
1031+
#map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d5)>
1032+
module {
1033+
func.func @fuse_only_as_long_as_result_map_is_permutation(%arg0: tensor<1x1x2x1xf32>, %arg1: tensor<1x1x2x1xf32>) -> tensor<1x1x2x4xf32> {
1034+
%c2 = arith.constant 2 : index
1035+
%c1 = arith.constant 1 : index
1036+
%cst = arith.constant 0.000000e+00 : f32
1037+
%c0 = arith.constant 0 : index
1038+
%0 = tensor.empty() : tensor<1x2x2x1xf32>
1039+
%1 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%0 : tensor<1x2x2x1xf32>) {
1040+
^bb0(%out: f32):
1041+
%6 = linalg.index 1 : index
1042+
%7 = linalg.index 2 : index
1043+
%8 = arith.cmpi ult, %6, %c1 : index
1044+
%9 = arith.cmpi ult, %7, %c2 : index
1045+
%10 = arith.andi %8, %9 : i1
1046+
%11 = scf.if %10 -> (f32) {
1047+
%extracted = tensor.extract %arg1[%c0, %6, %7, %c0] : tensor<1x1x2x1xf32>
1048+
scf.yield %extracted : f32
1049+
} else {
1050+
scf.yield %cst : f32
1051+
}
1052+
linalg.yield %11 : f32
1053+
} -> tensor<1x2x2x1xf32>
1054+
%2 = tensor.empty() : tensor<4x1x1x1xf32>
1055+
%3 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1 : tensor<1x2x2x1xf32>) outs(%2 : tensor<4x1x1x1xf32>) {
1056+
^bb0(%in: f32, %out: f32):
1057+
linalg.yield %in : f32
1058+
} -> tensor<4x1x1x1xf32>
1059+
%4 = tensor.empty() : tensor<1x1x2x4xf32>
1060+
%expanded = tensor.expand_shape %4 [[0], [1], [2], [3, 4, 5]] output_shape [1, 1, 2, 4, 1, 1] : tensor<1x1x2x4xf32> into tensor<1x1x2x4x1x1xf32>
1061+
%5 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %3 : tensor<1x1x2x1xf32>, tensor<4x1x1x1xf32>) outs(%expanded : tensor<1x1x2x4x1x1xf32>) {
1062+
^bb0(%in: f32, %in_0: f32, %out: f32):
1063+
%6 = arith.mulf %in, %in_0 : f32
1064+
%7 = arith.addf %6, %out : f32
1065+
linalg.yield %7 : f32
1066+
} -> tensor<1x1x2x4x1x1xf32>
1067+
%collapsed = tensor.collapse_shape %5 [[0], [1], [2], [3, 4, 5]] : tensor<1x1x2x4x1x1xf32> into tensor<1x1x2x4xf32>
1068+
return %collapsed : tensor<1x1x2x4xf32>
1069+
}
1070+
}
1071+
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
1072+
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0 * 4 + d1 * 2 + d2 + d3, 0, 0, 0)>
1073+
// CHECK: func.func @fuse_only_as_long_as_result_map_is_permutation
1074+
// CHECK-SAME: (%[[ARG0:.*]]: tensor<1x1x2x1xf32>, %[[ARG1:.*]]: tensor<1x1x2x1xf32>) -> tensor<1x1x2x4xf32> {
1075+
// CHECK-DAG: %[[OUT0:.+]] = tensor.empty() : tensor<1x2x2x1xf32>
1076+
// CHECK-DAG: %[[OUT1:.+]] = tensor.empty() : tensor<4x1x1x1xf32>
1077+
// CHECK: %[[FUSED:.+]]:2 = linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
1078+
// CHECK-SAME: outs(%[[OUT0]], %[[OUT1]] : tensor<1x2x2x1xf32>, tensor<4x1x1x1xf32>)
1079+
// CHECK-NOT: linalg.generic
1080+
// CHECK: tensor.expand_shape
1081+
// CHECK: linalg.generic {{.*}}, iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "reduction"]}
1082+
// CHECK-SAME: ins(%[[ARG0]], %[[FUSED]]#1 : tensor<1x1x2x1xf32>, tensor<4x1x1x1xf32>)

0 commit comments

Comments
 (0)