Commit 27df12d

nirvedhmeshram authored and frederik-h committed
[mlir][linalg] Add FoldReshapeWithGenericOpByCollapsing pattern (llvm#131029)
This pattern, which bubbles a tensor.collapse_shape up through its producer generic op, was missing from `populateFoldReshapeOpsByCollapsingPatterns`.

Signed-off-by: Nirvedh Meshram <[email protected]>
1 parent 5448042 commit 27df12d
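
For context, wiring the updated pattern set into a pass might look like the following minimal sketch. The helper function and the unconditional control function are illustrative assumptions, not part of this commit; only `populateFoldReshapeOpsByCollapsingPatterns` and `ControlFusionFn` come from the patched API.

```cpp
// Hypothetical usage sketch (not part of this commit): register the
// reshape-by-collapsing fusion patterns, which now include
// FoldReshapeWithGenericOpByCollapsing, and run them with the greedy driver.
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

static LogicalResult runCollapsingReshapeFusion(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  // Allow every fusion opportunity; a real pass would usually inspect
  // `fusedOperand` to restrict which reshapes get folded.
  linalg::ControlFusionFn controlFn = [](OpOperand *fusedOperand) {
    return true;
  };
  linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, controlFn);
  // Driver entry point assumed; older LLVM revisions spell it
  // applyPatternsAndFoldGreedily.
  return applyPatternsGreedily(root, std::move(patterns));
}
```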

File tree

3 files changed, +244 -0 lines changed


mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 78 additions & 0 deletions
@@ -1848,6 +1848,7 @@ namespace {
 class FoldWithProducerReshapeOpByCollapsing
     : public OpRewritePattern<GenericOp> {
 public:
+  // TODO : support fusion with all linalg ops, not just generic.
   FoldWithProducerReshapeOpByCollapsing(MLIRContext *context,
                                         ControlFusionFn foldReshapes,
                                         PatternBenefit benefit = 1)
@@ -1887,6 +1888,81 @@ class FoldWithProducerReshapeOpByCollapsing
   ControlFusionFn controlFoldingReshapes;
 };
 
+/// Pattern to fold a tensor.collapse_shape op with its producer generic op
+/// by collapsing the dimensions of the loop in the producer op.
+struct FoldReshapeWithGenericOpByCollapsing
+    : public OpRewritePattern<tensor::CollapseShapeOp> {
+
+  FoldReshapeWithGenericOpByCollapsing(MLIRContext *context,
+                                       ControlFusionFn foldReshapes,
+                                       PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::CollapseShapeOp>(context, benefit),
+        controlFoldingReshapes(std::move(foldReshapes)) {}
+
+  LogicalResult matchAndRewrite(tensor::CollapseShapeOp reshapeOp,
+                                PatternRewriter &rewriter) const override {
+    // Fold only if all constraints of fusing with reshape by collapsing are
+    // met.
+    auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
+    if (!producerResult) {
+      return rewriter.notifyMatchFailure(reshapeOp,
+                                         "source not produced by an operation");
+    }
+
+    // TODO : support fusion with all linalg producers, not just generic.
+    auto producer = dyn_cast<GenericOp>(producerResult.getOwner());
+    if (!producer) {
+      return rewriter.notifyMatchFailure(reshapeOp,
+                                         "producer not a generic op");
+    }
+
+    SmallVector<ReassociationIndices> collapsableIterationDims =
+        getCollapsableIterationSpaceDims(
+            producer,
+            producer.getDpsInitOperand(producerResult.getResultNumber()),
+            reshapeOp.getReassociationIndices());
+    if (collapsableIterationDims.empty()) {
+      return rewriter.notifyMatchFailure(
+          reshapeOp, "failed preconditions of fusion with producer generic op");
+    }
+
+    if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
+      return rewriter.notifyMatchFailure(reshapeOp,
+                                         "fusion blocked by control function");
+    }
+
+    std::optional<CollapseResult> collapseResult =
+        collapseOpIterationDims(producer, collapsableIterationDims, rewriter);
+    if (!collapseResult) {
+      return rewriter.notifyMatchFailure(
+          producer, "failed to do the fusion by collapsing transformation");
+    }
+
+    // Find the replacement for the reshape op. Since the replacements have the
+    // same type as the results of the original generic op, the consumer reshape
+    // op can be replaced by the source of the expand_shape op that defines
+    // the replacement.
+    Value reshapeReplacement =
+        (collapseResult
+             ->results)[cast<OpResult>(reshapeOp.getSrc()).getResultNumber()];
+    if (auto expandOp =
+            reshapeReplacement.getDefiningOp<tensor::ExpandShapeOp>()) {
+      reshapeReplacement = expandOp.getSrc();
+    }
+    rewriter.replaceOp(reshapeOp, reshapeReplacement);
+    rewriter.replaceOp(producer, collapseResult->results);
+    return success();
+  }
+
+private:
+  ControlFusionFn controlFoldingReshapes;
+};
+
 class FoldPadWithProducerReshapeOpByCollapsing
     : public OpRewritePattern<tensor::PadOp> {
 public:
@@ -2215,6 +2291,8 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
       controlFoldingReshapes);
   patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
       patterns.getContext(), controlFoldingReshapes);
+  patterns.add<FoldReshapeWithGenericOpByCollapsing>(patterns.getContext(),
+                                                     controlFoldingReshapes);
 }
 
 void mlir::linalg::populateElementwiseOpsFusionPatterns(

mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir

Lines changed: 160 additions & 0 deletions
@@ -638,3 +638,163 @@ func.func @fuse_by_collapsing_dynamic_pad(%arg0 : tensor<?x?x?x?xf32>,
 // CHECK:      %[[EXPAND:.+]] = tensor.expand_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5]]
 // CHECK-SAME:   output_shape [%[[PAD_SIZE0]], %[[S1]], %[[S2]], %[[PAD_SIZE1]], %[[S4]], %[[S5]]] : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
 // CHECK:      return %[[EXPAND]]
+
+// -----
+// Static problem sizes. Checks all aspects of fusion by collapsing with bubbling up collapse shapes.
+#map0 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6)>
+#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>
+#map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d0, d7, d3, d4, d5, d6)>
+func.func @fuse_by_collapsing_bubblecollapse(%arg0 : tensor<2x3x4x5x6x7x8x9xi32>,
+    %arg1 : tensor<2x3x4xi32>, %arg2 : tensor<5x6x7x8xi32>) -> (tensor<2x12x5x336x9xi32>, tensor<12x2x9x5x336xi32>) {
+  %init_0 = tensor.empty() : tensor<2x3x4x5x6x7x8x9xi32>
+  %init_1 = tensor.empty() : tensor<3x4x2x9x5x6x7x8xi32>
+  %generic:2 = linalg.generic {
+      indexing_maps = [#map0, #map1, #map2, #map3, #map4],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]}
+      ins(%arg0, %arg1, %arg2 : tensor<2x3x4x5x6x7x8x9xi32>, tensor<2x3x4xi32>, tensor<5x6x7x8xi32>)
+      outs(%init_0, %init_1 : tensor<2x3x4x5x6x7x8x9xi32>, tensor<3x4x2x9x5x6x7x8xi32>) {
+    ^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32, %b4 : i32):
+      %t0 = arith.addi %b0, %b1 : i32
+      %t1 = arith.addi %t0, %b2 : i32
+      linalg.yield %t1, %t1 : i32, i32
+  } -> (tensor<2x3x4x5x6x7x8x9xi32>, tensor<3x4x2x9x5x6x7x8xi32>)
+  %collapse_1 = tensor.collapse_shape %generic#0 [[0], [1, 2], [3], [4, 5, 6], [7]] : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x12x5x336x9xi32>
+  %collapse_2 = tensor.collapse_shape %generic#1 [[0, 1], [2], [3], [4], [5, 6, 7]] : tensor<3x4x2x9x5x6x7x8xi32> into tensor<12x2x9x5x336xi32>
+  return %collapse_1, %collapse_2 : tensor<2x12x5x336x9xi32>, tensor<12x2x9x5x336xi32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d0, d4, d2, d3)>
+// CHECK: func @fuse_by_collapsing_bubblecollapse(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8x9xi32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<2x3x4xi32>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<5x6x7x8xi32>
+// CHECK-DAG: %[[INIT0:.+]] = tensor.empty() : tensor<2x3x4x5x6x7x8x9xi32>
+// CHECK-DAG: %[[INIT1:.+]] = tensor.empty() : tensor<3x4x2x9x5x6x7x8xi32>
+// CHECK-DAG: %[[ARG0_RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]{{\]}}
+// CHECK-DAG: %[[ARG1_RESHAPE:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0], [1, 2]{{\]}}
+// CHECK-DAG: %[[ARG2_RESHAPE:.+]] = tensor.collapse_shape %[[ARG2]] {{\[}}[0], [1, 2, 3]{{\]}}
+// CHECK-DAG: %[[INIT0_RESHAPE:.+]] = tensor.collapse_shape %[[INIT0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]{{\]}}
+// CHECK-DAG: %[[INIT1_RESHAPE:.+]] = tensor.collapse_shape %[[INIT1]] {{\[}}[0, 1], [2], [3], [4], [5, 6, 7]{{\]}}
+// CHECK: %[[COLLAPSED_OP:.+]]:2 = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP0]], #[[MAP3]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[ARG0_RESHAPE]], %[[ARG1_RESHAPE]], %[[ARG2_RESHAPE]] :
+// CHECK-SAME: outs(%[[INIT0_RESHAPE]], %[[INIT1_RESHAPE]] :
+// CHECK: return %[[COLLAPSED_OP]]#0, %[[COLLAPSED_OP]]#1
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6)>
+#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>
+func.func @fuse_by_collapsing_indexing_op_bubblecollapse(%arg0 : tensor<2x3x4x5x6x7x8x9xi32>,
+    %arg1 : tensor<2x3x4xi32>, %arg2 : tensor<5x6x7x8xi32>) -> tensor<2x12x5x336x9xi32> {
+  %init = tensor.empty() : tensor<2x3x4x5x6x7x8x9xi32>
+  %generic = linalg.generic {
+      indexing_maps = [#map0, #map1, #map2, #map3],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]}
+      ins(%arg0, %arg1, %arg2 : tensor<2x3x4x5x6x7x8x9xi32>, tensor<2x3x4xi32>, tensor<5x6x7x8xi32>)
+      outs(%init : tensor<2x3x4x5x6x7x8x9xi32>) {
+    ^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32):
+      %iv0 = linalg.index 0: index
+      %iv1 = linalg.index 1: index
+      %t0 = arith.addi %iv0, %iv1 : index
+      %iv2 = linalg.index 2 : index
+      %t1 = arith.addi %t0, %iv2 : index
+      %iv3 = linalg.index 3 : index
+      %t2 = arith.addi %t1, %iv3 : index
+      %iv4 = linalg.index 4 : index
+      %t3 = arith.addi %t2, %iv4 : index
+      %iv5 = linalg.index 5 : index
+      %t4 = arith.addi %t3, %iv5 : index
+      %iv6 = linalg.index 6 : index
+      %t5 = arith.addi %t4, %iv6 : index
+      %iv7 = linalg.index 7 : index
+      %t6 = arith.addi %t5, %iv7 : index
+      %yield = arith.index_cast %t6 : index to i32
+      linalg.yield %yield : i32
+  } -> tensor<2x3x4x5x6x7x8x9xi32>
+  %collapse = tensor.collapse_shape %generic [[0], [1, 2], [3], [4, 5, 6], [7]] : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x12x5x336x9xi32>
+  return %collapse : tensor<2x12x5x336x9xi32>
+}
+// CHECK-LABEL: func @fuse_by_collapsing_indexing_op_bubblecollapse(
+// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
+// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
+// CHECK-DAG: %[[C7:.+]] = arith.constant 7 : index
+// CHECK: %[[IV0:.+]] = linalg.index 0
+// CHECK: %[[IV1:.+]] = linalg.index 1
+// CHECK: %[[REM_IV1:.+]] = arith.remsi %[[IV1]], %[[C4]]
+// CHECK: %[[DIV_IV1:.+]] = arith.divsi %[[IV1]], %[[C4]]
+// CHECK: %[[IV2:.+]] = linalg.index 2
+// CHECK: %[[IV3:.+]] = linalg.index 3
+// CHECK: %[[REM1_IV3:.+]] = arith.remsi %[[IV3]], %[[C8]]
+// CHECK: %[[DIV1_IV3:.+]] = arith.divsi %[[IV3]], %[[C8]]
+// CHECK: %[[REM2_IV3:.+]] = arith.remsi %[[DIV1_IV3]], %[[C7]]
+// CHECK: %[[DIV2_IV3:.+]] = arith.divsi %[[DIV1_IV3]], %[[C7]]
+// CHECK: %[[IV4:.+]] = linalg.index 4
+// CHECK: %[[T0:.+]] = arith.addi %[[IV0]], %[[DIV_IV1]]
+// CHECK: %[[T1:.+]] = arith.addi %[[T0]], %[[REM_IV1]]
+// CHECK: %[[T2:.+]] = arith.addi %[[T1]], %[[IV2]]
+// CHECK: %[[T3:.+]] = arith.addi %[[T2]], %[[DIV2_IV3]]
+// CHECK: %[[T4:.+]] = arith.addi %[[T3]], %[[REM2_IV3]]
+// CHECK: %[[T5:.+]] = arith.addi %[[T4]], %[[REM1_IV3]]
+// CHECK: %[[T6:.+]] = arith.addi %[[T5]], %[[IV4]]
+// CHECK: %[[YIELD:.+]] = arith.index_cast %[[T6]]
+// CHECK: linalg.yield %[[YIELD]]
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d7, d5, d6, d0, d1, d2, d3, d4)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d0)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d4, d1, d2, d3)>
+#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>
+func.func @fuse_by_collapsing_change_reshape_order_bubblecollapse(%arg0 : tensor<9x7x8x2x3x4x5x6xi32>,
+    %arg1 : tensor<7x8x2xi32>, %arg2 : tensor<6x3x4x5xi32>) -> tensor<2x60x6x56x9xi32> {
+  %init = tensor.empty() : tensor<2x3x4x5x6x7x8x9xi32>
+  %generic = linalg.generic {
+      indexing_maps = [#map0, #map1, #map2, #map3],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]}
+      ins(%arg0, %arg1, %arg2 : tensor<9x7x8x2x3x4x5x6xi32>, tensor<7x8x2xi32>, tensor<6x3x4x5xi32>)
+      outs(%init : tensor<2x3x4x5x6x7x8x9xi32>) {
+    ^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32):
+      %t0 = arith.addi %b0, %b1 : i32
+      %t1 = arith.addi %t0, %b2 : i32
+      linalg.yield %t1 : i32
+  } -> tensor<2x3x4x5x6x7x8x9xi32>
+  %collapse = tensor.collapse_shape %generic [[0], [1, 2, 3], [4], [5, 6], [7]] : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x60x6x56x9xi32>
+  return %collapse : tensor<2x60x6x56x9xi32>
+}
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d4, d3, d0, d1, d2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d0)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d2, d1)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CHECK: func @fuse_by_collapsing_change_reshape_order_bubblecollapse(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<9x7x8x2x3x4x5x6xi32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<7x8x2xi32>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<6x3x4x5xi32>
+// CHECK-DAG: %[[INIT:.+]] = tensor.empty()
+// CHECK-DAG: %[[ARG0_RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]{{\]}}
+// CHECK-DAG: %[[ARG1_RESHAPE:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0, 1], [2]{{\]}}
+// CHECK-DAG: %[[ARG2_RESHAPE:.+]] = tensor.collapse_shape %[[ARG2]] {{\[}}[0], [1, 2, 3]{{\]}}
+// CHECK-DAG: %[[INIT_RESHAPE:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0], [1, 2, 3], [4], [5, 6], [7]{{\]}}
+// CHECK: %[[COLLAPSED_OP:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[ARG0_RESHAPE]], %[[ARG1_RESHAPE]], %[[ARG2_RESHAPE]] :
+// CHECK-SAME: outs(%[[INIT_RESHAPE]] :
+// CHECK: return %[[COLLAPSED_OP]]
+
+// CONTROL: func @fuse_by_collapsing_change_reshape_order_bubblecollapse(
+// CONTROL-SAME: %[[ARG0:.+]]: tensor<9x7x8x2x3x4x5x6xi32>
+// CONTROL-SAME: %[[ARG1:.+]]: tensor<7x8x2xi32>
+// CONTROL-SAME: %[[ARG2:.+]]: tensor<6x3x4x5xi32>
+// CONTROL: %[[GENERIC:.+]] = linalg.generic
+// CONTROL-SAME: ins(%[[ARG0]],
+// CONTROL: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[GENERIC]]
+// CONTROL: return %[[COLLAPSE]]

mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp

Lines changed: 6 additions & 0 deletions
@@ -235,6 +235,12 @@ struct TestLinalgElementwiseFusion
         // Skip fusing the first operand.
         return fusedOperand->getOperandNumber();
       }
+      Operation *consumer = fusedOperand->getOwner();
+      if (auto collapseOp = dyn_cast<tensor::CollapseShapeOp>(consumer)) {
+        auto producerResult = dyn_cast<OpResult>(collapseOp.getSrc());
+        // Skip fusing the first result.
+        return producerResult.getResultNumber();
+      }
       return true;
     };
     linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, controlFn);
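
The control hook above also governs the new bubbling pattern: for `FoldReshapeWithGenericOpByCollapsing` the operand handed to the control function is the source of the `tensor.collapse_shape`, so its owner is the reshape op itself, as the updated test control function exploits. A minimal sketch of a downstream policy built on that fact; the function name and the single-result restriction are illustrative assumptions, not part of this commit.

```cpp
// Illustrative control function; ControlFusionFn and the tensor/linalg APIs
// are real, the policy itself is an assumption for the example.
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

using namespace mlir;

static linalg::ControlFusionFn makeBubbleCollapseControlFn() {
  return [](OpOperand *fusedOperand) -> bool {
    Operation *consumer = fusedOperand->getOwner();
    // For FoldReshapeWithGenericOpByCollapsing, `fusedOperand` is the source
    // of the tensor.collapse_shape being bubbled up through its producer.
    if (auto collapseOp = dyn_cast<tensor::CollapseShapeOp>(consumer)) {
      Operation *producer = collapseOp.getSrc().getDefiningOp();
      // Example policy: only bubble above single-result producers.
      return producer && producer->getNumResults() == 1;
    }
    // Leave every other fusion opportunity enabled.
    return true;
  };
}
```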
