
Commit 620f2a8

[mlir][linalg] Retain named ops in fuseWithReshapeByExpansion pattern
Signed-off-by: Nirvedh Meshram <[email protected]>
1 parent 9a32af2


2 files changed: +75 −22 lines

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 37 additions & 11 deletions
@@ -927,17 +927,43 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
     iteratorTypes[j] = type;

   TypeRange resultTypes = ValueRange(outputs).getTypes();
-  auto fusedOp =
-      rewriter.create<GenericOp>(linalgOp.getLoc(), resultTypes,
-                                 /*inputs=*/expandedOpOperands, outputs,
-                                 expandedOpIndexingMaps, iteratorTypes);
-  Region &fusedRegion = fusedOp->getRegion(0);
-  Region &originalRegion = linalgOp->getRegion(0);
-  rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());
-
-  // Update the index accesses after the expansion.
-  updateExpandedGenericOpRegion(rewriter, loc, fusedRegion, expansionInfo);
-
+  Operation *fusedOp;
+
+  TypeSwitch<Operation *>(linalgOp.getOperation())
+      .Case<GenericOp>([&](GenericOp op) {
+        fusedOp = rewriter.create<GenericOp>(
+            linalgOp.getLoc(), resultTypes, expandedOpOperands, outputs,
+            expandedOpIndexingMaps, iteratorTypes);
+        Region &fusedRegion = fusedOp->getRegion(0);
+        Region &originalRegion = linalgOp->getRegion(0);
+        rewriter.cloneRegionBefore(originalRegion, fusedRegion,
+                                   fusedRegion.begin());
+
+        // Update the index accesses after the expansion.
+        updateExpandedGenericOpRegion(rewriter, loc, fusedRegion,
+                                      expansionInfo);
+      })
+      .Case<TransposeOp>([&](TransposeOp op) {
+        SmallVector<ReassociationIndices> reassociation =
+            isExpanding ? expandingReshapeOp.getReassociationIndices()
+                        : collapsingReshapeOp.getReassociationIndices();
+        applyPermutationToVector(reassociation, op.getPermutation());
+        SmallVector<int64_t> newPerm;
+        for (auto reassoc : reassociation) {
+          for (auto dim : reassoc) {
+            newPerm.push_back(dim);
+          }
+        }
+        fusedOp = rewriter.create<TransposeOp>(
+            linalgOp.getLoc(), expandedOpOperands[0], outputs[0], newPerm);
+      })
+      // All other expandable linalg ops that are not generic or transpose can
+      // be cloned with the expanded input and output operands.
+      .Default([&](Operation *op) {
+        fusedOp = clone(
+            rewriter, linalgOp, resultTypes,
+            llvm::to_vector(llvm::concat<Value>(expandedOpOperands, outputs)));
+      });
   // Reshape the result values to their original shape if this is a collapsing
   // reshape folded into its consumer.
   SmallVector<Value> resultVals;
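
The TransposeOp case above turns the original permutation into a permutation over the expanded dimensions by reordering the reshape's reassociation groups and then flattening them. Below is a minimal standalone sketch of that computation (illustrative only; it substitutes plain std:: containers for SmallVector and ReassociationIndices), using the values from the new test added below, where reassociation [[0, 1], [2, 3]] meets permutation [1, 0]:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Reassociation groups of the reshape and the original transpose
  // permutation, as in @linalg_transpose_reshape_producer_fusion below.
  std::vector<std::vector<int64_t>> reassociation = {{0, 1}, {2, 3}};
  std::vector<int64_t> perm = {1, 0};

  // Mirror applyPermutationToVector: group i of the result is group perm[i]
  // of the input, reordering the groups to [[2, 3], [0, 1]].
  std::vector<std::vector<int64_t>> permuted;
  for (int64_t p : perm)
    permuted.push_back(reassociation[p]);

  // Flatten the reordered groups into the expanded-rank permutation, exactly
  // as the nested loop in the patch builds newPerm.
  std::vector<int64_t> newPerm;
  for (const auto &group : permuted)
    for (int64_t dim : group)
      newPerm.push_back(dim);

  // Prints "2 3 0 1", matching `permutation = [2, 3, 0, 1]` in the test.
  for (int64_t d : newPerm)
    std::cout << d << ' ';
  std::cout << '\n';
  return 0;
}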

mlir/test/Dialect/Linalg/reshape_fusion.mlir

Lines changed: 38 additions & 11 deletions
@@ -753,7 +753,6 @@ func.func @linalg_add_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
   return %1 : tensor<?x?x4x5xf32>
 }

-// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 // CHECK: func @linalg_add_reshape_consumer_fusion
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
@@ -774,18 +773,13 @@ func.func @linalg_add_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
 // CHECK: %[[DIM_5:.+]] = tensor.dim %[[ARG2]], %[[C1]] : tensor<?x?xf32>
 // CHECK: %[[VAL_2:.+]] = arith.divsi %[[DIM_5]], %[[C20]] : index
 // CHECK: %[[T3:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_4]], %[[VAL_2]], 4, 5] : tensor<?x?xf32> into tensor<?x?x4x5xf32>
-// CHECK: %[[T4:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
-// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
+// CHECK: %[[T4:.+]] = linalg.add
 // CHECK-SAME: ins(%[[T1]], %[[T2]] : tensor<?x?x4x5xf32>, tensor<?x?x4x5xf32>)
 // CHECK-SAME: outs(%[[T3]] : tensor<?x?x4x5xf32>)
 // CHECK: return %[[T4]] : tensor<?x?x4x5xf32>

 // -----

-#map0 = affine_map<(d0, d1, d2) -> (d2, d0)>
-#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
 func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
     %arg1 : tensor<?x?xf32>,
     %arg2 : tensor<?x?xf32>) ->
@@ -798,7 +792,6 @@ func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
   return %1 : tensor<?x?xf32>
 }

-// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 // CHECK: func @linalg_add_reshape_producer_fusion
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
@@ -817,16 +810,50 @@ func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
 // CHECK: %[[VAL_2:.+]] = arith.divsi %[[DIM_1]], %[[C7]] : index
 // CHECK: %[[VAL_3:.+]] = arith.divsi %[[DIM_2]], %[[C8]] : index
 // CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_2]], 7, %[[VAL_3]], 8] : tensor<?x?xf32> into tensor<?x7x?x8xf32>
-// CHECK: %[[T3:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]], #[[$MAP]]]
-// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
+// CHECK: %[[T3:.+]] = linalg.add
 // CHECK-SAME: ins(%[[ARG0]], %[[T1]] : tensor<?x7x?x8xf32>, tensor<?x7x?x8xf32>)
 // CHECK-SAME: outs(%[[T2]] : tensor<?x7x?x8xf32>)
 // CHECK: %[[T4:.+]] = tensor.collapse_shape %[[T3]]
 // CHECK-SAME: [0, 1], [2, 3]
 // CHECK-SAME: tensor<?x7x?x8xf32> into tensor<?x?xf32>
 // CHECK: return %[[T4]]

+// -----
+
+func.func @linalg_transpose_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
+    %arg1 : tensor<?x?xf32>) ->
+    tensor<?x?xf32>
+{
+  %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]] :
+    tensor<?x7x?x8xf32> into tensor<?x?xf32>
+  %1 = linalg.transpose ins(%0 : tensor<?x?xf32>)
+    outs(%arg1 : tensor<?x?xf32>) permutation = [1, 0]
+  return %1 : tensor<?x?xf32>
+}
+
+// CHECK: func @linalg_transpose_reshape_producer_fusion
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
+// CHECK-DAG: %[[C7:.+]] = arith.constant 7 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
+// CHECK-DAG: %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
+// CHECK-DAG: %[[VAL_0:.+]] = arith.divsi %[[DIM_0]], %[[C7]] : index
+// CHECK-DAG: %[[VAL_1:.+]] = arith.divsi %[[DIM]], %[[C8]] : index
+// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_1]], 8, %[[VAL_0]], 7] : tensor<?x?xf32> into tensor<?x8x?x7xf32>
+// CHECK: %[[T2:.+]] = linalg.transpose
+// CHECK-SAME: ins(%[[ARG0]] : tensor<?x7x?x8xf32>)
+// CHECK-SAME: outs(%[[T1]] : tensor<?x8x?x7xf32>)
+// CHECK-SAME: permutation = [2, 3, 0, 1]
+// CHECK: %[[T3:.+]] = tensor.collapse_shape %[[T2]]
+// CHECK-SAME: [0, 1], [2, 3]
+// CHECK-SAME: tensor<?x8x?x7xf32> into tensor<?x?xf32>
+// CHECK: return %[[T3]]
+
+
+
 // -----

 func.func @fuse_by_expanding_pad(%arg0 : tensor<2x3x4x5x6x7x8x9xi32>) -> tensor<8x12x17x336x14xi32> {
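
For context, these rewrites are reached through the expansion-based reshape-fusion pattern set. The sketch below shows one plausible way to drive them from a pass; it assumes the existing linalg entry points populateFoldReshapeOpsByExpansionPatterns and ControlFusionFn together with the greedy rewrite driver, and the always-true control function is a placeholder rather than anything this commit prescribes:

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/Operation.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Collect the reshape-by-expansion fusion patterns (which now retain named
// ops such as linalg.add and linalg.transpose) and apply them greedily to
// `op` and everything nested under it.
static LogicalResult fuseReshapesByExpansion(Operation *op) {
  RewritePatternSet patterns(op->getContext());
  // Placeholder control function: attempt to fuse every candidate operand.
  linalg::ControlFusionFn controlFn = [](OpOperand *fusedOperand) {
    return true;
  };
  linalg::populateFoldReshapeOpsByExpansionPatterns(patterns, controlFn);
  return applyPatternsAndFoldGreedily(op, std::move(patterns));
}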
