Commit f41ef60

[mlir][linalg] Fix partial fuse by collapse (llvm#136326)
Similar to `FoldWithProducerReshapeOpByCollapsing`, `FoldReshapeWithGenericOpByCollapsing` needs to be able to handle partial fusion of a reshape by collapsing. This means that the source of the generated `expand_shape` op (i.e. the collapsed linalg op) might not match the type of the original `collapse_shape` op. This change instead replaces the original linalg op with the new `expand_shape` op, which is guaranteed to have the same type.

Signed-off-by: Ian Wood <[email protected]>
1 parent 4f387eb commit f41ef60
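To make the "partial" case concrete, below is a rough sketch of the IR this pattern is expected to produce for the new @partial_fuse_by_collapsing test added in this commit. It is reconstructed from that test's CHECK lines rather than copied from the pass output, so the renumbered affine maps, the exact reassociation indices, the output_shape/tensor.dim plumbing for the dynamic dimension, and the _sketch suffix on the function name are illustrative assumptions:

// Sketch reconstructed from the new test's CHECK lines; exact attributes and
// value names may differ from the real pass output.
// Only the (?, 32) group of iteration dims is contiguous in every operand, so
// only that group is collapsed into the producer. The trailing expand_shape
// restores the producer's original result type, and the producer is replaced
// by it directly; its source type differs from the consumer collapse_shape's
// grouping, which is exactly the partial case this commit fixes.
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>
func.func @partial_fuse_by_collapsing_sketch(%arg0: tensor<4x?x32x128x192xf16>,
    %arg1: tensor<4x128x192x?x32xf32>) -> tensor<512x192x?xf32> {
  // Dynamic extent of the (?, 32) group, needed to rebuild the expanded type.
  %c1 = arith.constant 1 : index
  %d = tensor.dim %arg0, %c1 : tensor<4x?x32x128x192xf16>
  %collapsed0 = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4]]
      : tensor<4x?x32x128x192xf16> into tensor<4x?x128x192xf16>
  %collapsed1 = tensor.collapse_shape %arg1 [[0], [1], [2], [3, 4]]
      : tensor<4x128x192x?x32xf32> into tensor<4x128x192x?xf32>
  %generic = linalg.generic {indexing_maps = [#map, #map1],
      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
      ins(%collapsed0 : tensor<4x?x128x192xf16>)
      outs(%collapsed1 : tensor<4x128x192x?xf32>) {
  ^bb0(%in: f16, %out: f32):
    linalg.yield %out : f32
  } -> tensor<4x128x192x?xf32>
  // The collapsed producer is re-expanded to its original result type and the
  // original linalg op is replaced with this expand_shape.
  %expanded = tensor.expand_shape %generic [[0], [1], [2], [3, 4]]
      output_shape [4, 128, 192, %d, 32]
      : tensor<4x128x192x?xf32> into tensor<4x128x192x?x32xf32>
  // The consumer collapse_shape remains, now fed by the expand_shape.
  %collapsed = tensor.collapse_shape %expanded [[0, 1], [2], [3, 4]]
      : tensor<4x128x192x?x32xf32> into tensor<512x192x?xf32>
  return %collapsed : tensor<512x192x?xf32>
}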

File tree

2 files changed: +28 -17 lines changed

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 0 additions & 17 deletions
@@ -1907,23 +1907,6 @@ struct FoldReshapeWithGenericOpByCollapsing
           producer, "failed to do the fusion by collapsing transformation");
     }
 
-    if (!collapseResult) {
-      return rewriter.notifyMatchFailure(reshapeOp,
-                                         "fusion by expansion failed");
-    }
-
-    // Find the replacement for the reshape op. Since the replacements have the
-    // same type as the returns of the original generic op, the consumer reshape
-    // op can be replaced by the source of the expand_shape op that defines
-    // the replacement.
-    Value reshapeReplacement =
-        (collapseResult
-             ->results)[cast<OpResult>(reshapeOp.getSrc()).getResultNumber()];
-    if (auto expandOp =
-            reshapeReplacement.getDefiningOp<tensor::ExpandShapeOp>()) {
-      reshapeReplacement = expandOp.getSrc();
-    }
-    rewriter.replaceOp(reshapeOp, reshapeReplacement);
     rewriter.replaceOp(producer, collapseResult->results);
     return success();
   }

mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir

Lines changed: 28 additions & 0 deletions
@@ -830,3 +830,31 @@ func.func @fuse_by_collapsing_correct_insertion(%arg0 : tensor<?x?xf32>,
 //       CHECK:   %[[EXPANDED:.+]] = tensor.expand_shape %[[OUT]]
 //       CHECK:   %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]]
 //       CHECK:   return %[[OUT]], %[[DIM]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4, d1, d2)>
+func.func @partial_fuse_by_collapsing(%arg0: tensor<4x?x32x128x192xf16>, %arg1: tensor<4x128x192x?x32xf32>) -> tensor<512x192x?xf32> {
+  %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<4x?x32x128x192xf16>) outs(%arg1 : tensor<4x128x192x?x32xf32>) {
+  ^bb0(%in: f16, %out: f32):
+    linalg.yield %out : f32
+  } -> tensor<4x128x192x?x32xf32>
+  %collapsed = tensor.collapse_shape %0 [[0, 1], [2], [3, 4]] : tensor<4x128x192x?x32xf32> into tensor<512x192x?xf32>
+  return %collapsed : tensor<512x192x?xf32>
+}
+// CHECK-LABEL: func @partial_fuse_by_collapsing
+//  CHECK-SAME:   %[[ARG0:.+]]: tensor<4x?x32x128x192xf16>
+//  CHECK-SAME:   %[[ARG1:.+]]: tensor<4x128x192x?x32xf32>
+//   CHECK-DAG:   %[[COLLAPSED0:.+]] = tensor.collapse_shape %[[ARG0]]
+//  CHECK-SAME:     tensor<4x?x32x128x192xf16> into tensor<4x?x128x192xf16>
+//   CHECK-DAG:   %[[COLLAPSED1:.+]] = tensor.collapse_shape %[[ARG1]]
+//  CHECK-SAME:     tensor<4x128x192x?x32xf32> into tensor<4x128x192x?xf32>
+//       CHECK:   %[[GENERIC:.+]] = linalg.generic
+//  CHECK-SAME:     ins(%[[COLLAPSED0]]
+//  CHECK-SAME:     outs(%[[COLLAPSED1]]
+//       CHECK:   %[[EXPANDED:.+]] = tensor.expand_shape %[[GENERIC]]
+//  CHECK-SAME:     tensor<4x128x192x?xf32> into tensor<4x128x192x?x32xf32>
+//       CHECK:   %[[COLLAPSED:.+]] = tensor.collapse_shape %[[EXPANDED]]
+//  CHECK-SAME:     tensor<4x128x192x?x32xf32> into tensor<512x192x?xf32>
+//       CHECK:   return %[[COLLAPSED]] : tensor<512x192x?xf32>
