@@ -638,3 +638,163 @@ func.func @fuse_by_collapsing_dynamic_pad(%arg0 : tensor<?x?x?x?xf32>,
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5]]
// CHECK-SAME: output_shape [%[[PAD_SIZE0]], %[[S1]], %[[S2]], %[[PAD_SIZE1]], %[[S4]], %[[S5]]] : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
// CHECK: return %[[EXPAND]]
+
+// -----
+// Static problem sizes. Checks all aspects of fusion by collapsing with bubbling up collapse shapes.
+#map0 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6)>
+#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>
+#map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d0, d7, d3, d4, d5, d6)>
+func.func @fuse_by_collapsing_bubblecollapse(%arg0 : tensor<2x3x4x5x6x7x8x9xi32>,
+    %arg1 : tensor<2x3x4xi32>, %arg2 : tensor<5x6x7x8xi32>) -> (tensor<2x12x5x336x9xi32>, tensor<12x2x9x5x336xi32>) {
+  %init_0 = tensor.empty() : tensor<2x3x4x5x6x7x8x9xi32>
+  %init_1 = tensor.empty() : tensor<3x4x2x9x5x6x7x8xi32>
+  %generic:2 = linalg.generic {
+      indexing_maps = [#map0, #map1, #map2, #map3, #map4],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]}
+      ins(%arg0, %arg1, %arg2 : tensor<2x3x4x5x6x7x8x9xi32>, tensor<2x3x4xi32>, tensor<5x6x7x8xi32>)
+      outs(%init_0, %init_1 : tensor<2x3x4x5x6x7x8x9xi32>, tensor<3x4x2x9x5x6x7x8xi32>) {
+    ^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32, %b4 : i32):
+      %t0 = arith.addi %b0, %b1 : i32
+      %t1 = arith.addi %t0, %b2 : i32
+      linalg.yield %t1, %t1 : i32, i32
+  } -> (tensor<2x3x4x5x6x7x8x9xi32>, tensor<3x4x2x9x5x6x7x8xi32>)
+  %collapse_1 = tensor.collapse_shape %generic#0 [[0], [1, 2], [3], [4, 5, 6], [7]] : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x12x5x336x9xi32>
+  %collapse_2 = tensor.collapse_shape %generic#1 [[0, 1], [2], [3], [4], [5, 6, 7]] : tensor<3x4x2x9x5x6x7x8xi32> into tensor<12x2x9x5x336xi32>
+  return %collapse_1, %collapse_2 : tensor<2x12x5x336x9xi32>, tensor<12x2x9x5x336xi32>
+}
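+// Note: with the reassociations above, the rewrite checked below collapses the
+// 8-d iteration space to 5-d; operands and inits are collapsed with
+// tensor.collapse_shape and the indexing maps are remapped onto the collapsed
+// loops (e.g. the permuted output map #map4 becomes (d1, d0, d4, d2, d3)).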
+ // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+ // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>
+ // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>
+ // CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d0, d4, d2, d3)>
+ // CHECK: func @fuse_by_collapsing_bubblecollapse(
+ // CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8x9xi32>
+ // CHECK-SAME: %[[ARG1:.+]]: tensor<2x3x4xi32>
+ // CHECK-SAME: %[[ARG2:.+]]: tensor<5x6x7x8xi32>
+ // CHECK-DAG: %[[INIT0:.+]] = tensor.empty() : tensor<2x3x4x5x6x7x8x9xi32>
+ // CHECK-DAG: %[[INIT1:.+]] = tensor.empty() : tensor<3x4x2x9x5x6x7x8xi32>
+ // CHECK-DAG: %[[ARG0_RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]{{\]}}
+ // CHECK-DAG: %[[ARG1_RESHAPE:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0], [1, 2]{{\]}}
+ // CHECK-DAG: %[[ARG2_RESHAPE:.+]] = tensor.collapse_shape %[[ARG2]] {{\[}}[0], [1, 2, 3]{{\]}}
+ // CHECK-DAG: %[[INIT0_RESHAPE:.+]] = tensor.collapse_shape %[[INIT0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]{{\]}}
+ // CHECK-DAG: %[[INIT1_RESHAPE:.+]] = tensor.collapse_shape %[[INIT1]] {{\[}}[0, 1], [2], [3], [4], [5, 6, 7]{{\]}}
+ // CHECK: %[[COLLAPSED_OP:.+]]:2 = linalg.generic
+ // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP0]], #[[MAP3]]]
+ // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]
+ // CHECK-SAME: ins(%[[ARG0_RESHAPE]], %[[ARG1_RESHAPE]], %[[ARG2_RESHAPE]] :
+ // CHECK-SAME: outs(%[[INIT0_RESHAPE]], %[[INIT1_RESHAPE]] :
+ // CHECK: return %[[COLLAPSED_OP]]#0, %[[COLLAPSED_OP]]#1
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6)>
+#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>
+func.func @fuse_by_collapsing_indexing_op_bubblecollapse(%arg0 : tensor<2x3x4x5x6x7x8x9xi32>,
+    %arg1 : tensor<2x3x4xi32>, %arg2 : tensor<5x6x7x8xi32>) -> tensor<2x12x5x336x9xi32> {
+  %init = tensor.empty() : tensor<2x3x4x5x6x7x8x9xi32>
+  %generic = linalg.generic {
+      indexing_maps = [#map0, #map1, #map2, #map3],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]}
+      ins(%arg0, %arg1, %arg2 : tensor<2x3x4x5x6x7x8x9xi32>, tensor<2x3x4xi32>, tensor<5x6x7x8xi32>)
+      outs(%init : tensor<2x3x4x5x6x7x8x9xi32>) {
+    ^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32):
+      %iv0 = linalg.index 0 : index
+      %iv1 = linalg.index 1 : index
+      %t0 = arith.addi %iv0, %iv1 : index
+      %iv2 = linalg.index 2 : index
+      %t1 = arith.addi %t0, %iv2 : index
+      %iv3 = linalg.index 3 : index
+      %t2 = arith.addi %t1, %iv3 : index
+      %iv4 = linalg.index 4 : index
+      %t3 = arith.addi %t2, %iv4 : index
+      %iv5 = linalg.index 5 : index
+      %t4 = arith.addi %t3, %iv5 : index
+      %iv6 = linalg.index 6 : index
+      %t5 = arith.addi %t4, %iv6 : index
+      %iv7 = linalg.index 7 : index
+      %t6 = arith.addi %t5, %iv7 : index
+      %yield = arith.index_cast %t6 : index to i32
+      linalg.yield %yield : i32
+  } -> tensor<2x3x4x5x6x7x8x9xi32>
+  %collapse = tensor.collapse_shape %generic [[0], [1, 2], [3], [4, 5, 6], [7]] : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x12x5x336x9xi32>
+  return %collapse : tensor<2x12x5x336x9xi32>
+}
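+// Note: fusing the collapse delinearizes the collapsed linalg.index values. For the
+// [1, 2] group (3x4) the original indices are recovered as iv1 divsi 4 and iv1 remsi 4;
+// for the [4, 5, 6] group (6x7x8) as (iv3 divsi 8) divsi 7, (iv3 divsi 8) remsi 7, and
+// iv3 remsi 8, as the checks below verify.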
+ // CHECK-LABEL: func @fuse_by_collapsing_indexing_op_bubblecollapse(
+ // CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
+ // CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
+ // CHECK-DAG: %[[C7:.+]] = arith.constant 7 : index
+ // CHECK: %[[IV0:.+]] = linalg.index 0
+ // CHECK: %[[IV1:.+]] = linalg.index 1
+ // CHECK: %[[REM_IV1:.+]] = arith.remsi %[[IV1]], %[[C4]]
+ // CHECK: %[[DIV_IV1:.+]] = arith.divsi %[[IV1]], %[[C4]]
+ // CHECK: %[[IV2:.+]] = linalg.index 2
+ // CHECK: %[[IV3:.+]] = linalg.index 3
+ // CHECK: %[[REM1_IV3:.+]] = arith.remsi %[[IV3]], %[[C8]]
+ // CHECK: %[[DIV1_IV3:.+]] = arith.divsi %[[IV3]], %[[C8]]
+ // CHECK: %[[REM2_IV3:.+]] = arith.remsi %[[DIV1_IV3]], %[[C7]]
+ // CHECK: %[[DIV2_IV3:.+]] = arith.divsi %[[DIV1_IV3]], %[[C7]]
+ // CHECK: %[[IV4:.+]] = linalg.index 4
+ // CHECK: %[[T0:.+]] = arith.addi %[[IV0]], %[[DIV_IV1]]
+ // CHECK: %[[T1:.+]] = arith.addi %[[T0]], %[[REM_IV1]]
+ // CHECK: %[[T2:.+]] = arith.addi %[[T1]], %[[IV2]]
+ // CHECK: %[[T3:.+]] = arith.addi %[[T2]], %[[DIV2_IV3]]
+ // CHECK: %[[T4:.+]] = arith.addi %[[T3]], %[[REM2_IV3]]
+ // CHECK: %[[T5:.+]] = arith.addi %[[T4]], %[[REM1_IV3]]
+ // CHECK: %[[T6:.+]] = arith.addi %[[T5]], %[[IV4]]
+ // CHECK: %[[YIELD:.+]] = arith.index_cast %[[T6]]
+ // CHECK: linalg.yield %[[YIELD]]
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d7, d5, d6, d0, d1, d2, d3, d4)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d0)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d4, d1, d2, d3)>
+#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>
+func.func @fuse_by_collapsing_change_reshape_order_bubblecollapse(%arg0 : tensor<9x7x8x2x3x4x5x6xi32>,
+    %arg1 : tensor<7x8x2xi32>, %arg2 : tensor<6x3x4x5xi32>) -> tensor<2x60x6x56x9xi32> {
+  %init = tensor.empty() : tensor<2x3x4x5x6x7x8x9xi32>
+  %generic = linalg.generic {
+      indexing_maps = [#map0, #map1, #map2, #map3],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]}
+      ins(%arg0, %arg1, %arg2 : tensor<9x7x8x2x3x4x5x6xi32>, tensor<7x8x2xi32>, tensor<6x3x4x5xi32>)
+      outs(%init : tensor<2x3x4x5x6x7x8x9xi32>) {
+    ^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32):
+      %t0 = arith.addi %b0, %b1 : i32
+      %t1 = arith.addi %t0, %b2 : i32
+      linalg.yield %t1 : i32
+  } -> tensor<2x3x4x5x6x7x8x9xi32>
+  %collapse = tensor.collapse_shape %generic [[0], [1, 2, 3], [4], [5, 6], [7]] : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x60x6x56x9xi32>
+  return %collapse : tensor<2x60x6x56x9xi32>
+}
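+// Note: the collapse groups are defined on the output loops ([0], [1, 2, 3], [4], [5, 6], [7]),
+// so each operand is collapsed along the corresponding permuted dimensions; e.g. %arg0,
+// indexed as (d7, d5, d6, d0, d1, d2, d3, d4), collapses as [[0], [1, 2], [3], [4, 5, 6], [7]],
+// as the checks below verify.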
+
+ // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d4, d3, d0, d1, d2)>
+ // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d0)>
+ // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d2, d1)>
+ // CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+ // CHECK: func @fuse_by_collapsing_change_reshape_order_bubblecollapse(
+ // CHECK-SAME: %[[ARG0:.+]]: tensor<9x7x8x2x3x4x5x6xi32>
+ // CHECK-SAME: %[[ARG1:.+]]: tensor<7x8x2xi32>
+ // CHECK-SAME: %[[ARG2:.+]]: tensor<6x3x4x5xi32>
+ // CHECK-DAG: %[[INIT:.+]] = tensor.empty()
+ // CHECK-DAG: %[[ARG0_RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]{{\]}}
+ // CHECK-DAG: %[[ARG1_RESHAPE:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0, 1], [2]{{\]}}
+ // CHECK-DAG: %[[ARG2_RESHAPE:.+]] = tensor.collapse_shape %[[ARG2]] {{\[}}[0], [1, 2, 3]{{\]}}
+ // CHECK-DAG: %[[INIT_RESHAPE:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0], [1, 2, 3], [4], [5, 6], [7]{{\]}}
+ // CHECK: %[[COLLAPSED_OP:.+]] = linalg.generic
+ // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]]]
+ // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]
+ // CHECK-SAME: ins(%[[ARG0_RESHAPE]], %[[ARG1_RESHAPE]], %[[ARG2_RESHAPE]] :
+ // CHECK-SAME: outs(%[[INIT_RESHAPE]] :
+ // CHECK: return %[[COLLAPSED_OP]]
+
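+// Note: the CONTROL run (presumably driven by a control function that rejects
+// fusion for this op) is expected to leave the generic and the trailing
+// tensor.collapse_shape unfused.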
+ // CONTROL: func @fuse_by_collapsing_change_reshape_order_bubblecollapse(
+ // CONTROL-SAME: %[[ARG0:.+]]: tensor<9x7x8x2x3x4x5x6xi32>
+ // CONTROL-SAME: %[[ARG1:.+]]: tensor<7x8x2xi32>
+ // CONTROL-SAME: %[[ARG2:.+]]: tensor<6x3x4x5xi32>
+ // CONTROL: %[[GENERIC:.+]] = linalg.generic
+ // CONTROL-SAME: ins(%[[ARG0]],
+ // CONTROL: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[GENERIC]]
+ // CONTROL: return %[[COLLAPSE]]