@@ -638,3 +638,163 @@ func.func @fuse_by_collapsing_dynamic_pad(%arg0 : tensor<?x?x?x?xf32>,
 // CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5]]
 // CHECK-SAME: output_shape [%[[PAD_SIZE0]], %[[S1]], %[[S2]], %[[PAD_SIZE1]], %[[S4]], %[[S5]]] : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
 // CHECK: return %[[EXPAND]]
+
+// -----
+// Static problem sizes. Checks all aspects of fusion by collapsing when bubbling up collapse shapes.
+#map0 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6)>
+#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>
+#map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d0, d7, d3, d4, d5, d6)>
+func.func @fuse_by_collapsing_bubblecollapse(%arg0 : tensor<2x3x4x5x6x7x8x9xi32>,
+    %arg1 : tensor<2x3x4xi32>, %arg2 : tensor<5x6x7x8xi32>) -> (tensor<2x12x5x336x9xi32>, tensor<12x2x9x5x336xi32>) {
+  %init_0 = tensor.empty() : tensor<2x3x4x5x6x7x8x9xi32>
+  %init_1 = tensor.empty() : tensor<3x4x2x9x5x6x7x8xi32>
+  %generic:2 = linalg.generic {
+      indexing_maps = [#map0, #map1, #map2, #map3, #map4],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]}
+      ins(%arg0, %arg1, %arg2 : tensor<2x3x4x5x6x7x8x9xi32>, tensor<2x3x4xi32>, tensor<5x6x7x8xi32>)
+      outs(%init_0, %init_1 : tensor<2x3x4x5x6x7x8x9xi32>, tensor<3x4x2x9x5x6x7x8xi32>) {
+    ^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32, %b4 : i32):
+      %t0 = arith.addi %b0, %b1 : i32
+      %t1 = arith.addi %t0, %b2 : i32
+      linalg.yield %t1, %t1 : i32, i32
+  } -> (tensor<2x3x4x5x6x7x8x9xi32>, tensor<3x4x2x9x5x6x7x8xi32>)
+  %collapse_1 = tensor.collapse_shape %generic#0 [[0], [1, 2], [3], [4, 5, 6], [7]] : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x12x5x336x9xi32>
+  %collapse_2 = tensor.collapse_shape %generic#1 [[0, 1], [2], [3], [4], [5, 6, 7]] : tensor<3x4x2x9x5x6x7x8xi32> into tensor<12x2x9x5x336xi32>
+  return %collapse_1, %collapse_2 : tensor<2x12x5x336x9xi32>, tensor<12x2x9x5x336xi32>
+}
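+// Note: the 8-D iteration space collapses to 5-D along [0], [1, 2], [3], [4, 5, 6], [7],
+// matching the reassociation of %collapse_1; #map4's permutation of those groups
+// becomes #[[MAP3]] below.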
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d0, d4, d2, d3)>
+// CHECK: func @fuse_by_collapsing_bubblecollapse(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8x9xi32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<2x3x4xi32>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<5x6x7x8xi32>
+// CHECK-DAG: %[[INIT0:.+]] = tensor.empty() : tensor<2x3x4x5x6x7x8x9xi32>
+// CHECK-DAG: %[[INIT1:.+]] = tensor.empty() : tensor<3x4x2x9x5x6x7x8xi32>
+// CHECK-DAG: %[[ARG0_RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]{{\]}}
+// CHECK-DAG: %[[ARG1_RESHAPE:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0], [1, 2]{{\]}}
+// CHECK-DAG: %[[ARG2_RESHAPE:.+]] = tensor.collapse_shape %[[ARG2]] {{\[}}[0], [1, 2, 3]{{\]}}
+// CHECK-DAG: %[[INIT0_RESHAPE:.+]] = tensor.collapse_shape %[[INIT0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]{{\]}}
+// CHECK-DAG: %[[INIT1_RESHAPE:.+]] = tensor.collapse_shape %[[INIT1]] {{\[}}[0, 1], [2], [3], [4], [5, 6, 7]{{\]}}
+// CHECK: %[[COLLAPSED_OP:.+]]:2 = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP0]], #[[MAP3]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[ARG0_RESHAPE]], %[[ARG1_RESHAPE]], %[[ARG2_RESHAPE]] :
+// CHECK-SAME: outs(%[[INIT0_RESHAPE]], %[[INIT1_RESHAPE]] :
+// CHECK: return %[[COLLAPSED_OP]]#0, %[[COLLAPSED_OP]]#1
+
+// -----
+
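+// Checks that linalg.index ops are remapped when fusing by collapsing: each
+// collapsed index must be delinearized back into the original loop indices
+// (see the divsi/remsi checks below).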
+#map0 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6)>
+#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>
+func.func @fuse_by_collapsing_indexing_op_bubblecollapse(%arg0 : tensor<2x3x4x5x6x7x8x9xi32>,
+    %arg1 : tensor<2x3x4xi32>, %arg2 : tensor<5x6x7x8xi32>) -> tensor<2x12x5x336x9xi32> {
+  %init = tensor.empty() : tensor<2x3x4x5x6x7x8x9xi32>
+  %generic = linalg.generic {
+      indexing_maps = [#map0, #map1, #map2, #map3],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]}
+      ins(%arg0, %arg1, %arg2 : tensor<2x3x4x5x6x7x8x9xi32>, tensor<2x3x4xi32>, tensor<5x6x7x8xi32>)
+      outs(%init : tensor<2x3x4x5x6x7x8x9xi32>) {
+    ^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32):
+      %iv0 = linalg.index 0 : index
+      %iv1 = linalg.index 1 : index
+      %t0 = arith.addi %iv0, %iv1 : index
+      %iv2 = linalg.index 2 : index
+      %t1 = arith.addi %t0, %iv2 : index
+      %iv3 = linalg.index 3 : index
+      %t2 = arith.addi %t1, %iv3 : index
+      %iv4 = linalg.index 4 : index
+      %t3 = arith.addi %t2, %iv4 : index
+      %iv5 = linalg.index 5 : index
+      %t4 = arith.addi %t3, %iv5 : index
+      %iv6 = linalg.index 6 : index
+      %t5 = arith.addi %t4, %iv6 : index
+      %iv7 = linalg.index 7 : index
+      %t6 = arith.addi %t5, %iv7 : index
+      %yield = arith.index_cast %t6 : index to i32
+      linalg.yield %yield : i32
+  } -> tensor<2x3x4x5x6x7x8x9xi32>
+  %collapse = tensor.collapse_shape %generic [[0], [1, 2], [3], [4, 5, 6], [7]] : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x12x5x336x9xi32>
+  return %collapse : tensor<2x12x5x336x9xi32>
+}
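+// The collapsed iv1 (range 12 = 3 * 4) is delinearized into the original
+// d1 = iv1 divsi 4 and d2 = iv1 remsi 4; the collapsed iv3 (range 336 = 6 * 7 * 8)
+// is delinearized into d4, d5, d6 by successive divsi/remsi with 8 and then 7.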
+// CHECK-LABEL: func @fuse_by_collapsing_indexing_op_bubblecollapse(
+// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
+// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
+// CHECK-DAG: %[[C7:.+]] = arith.constant 7 : index
+// CHECK: %[[IV0:.+]] = linalg.index 0
+// CHECK: %[[IV1:.+]] = linalg.index 1
+// CHECK: %[[REM_IV1:.+]] = arith.remsi %[[IV1]], %[[C4]]
+// CHECK: %[[DIV_IV1:.+]] = arith.divsi %[[IV1]], %[[C4]]
+// CHECK: %[[IV2:.+]] = linalg.index 2
+// CHECK: %[[IV3:.+]] = linalg.index 3
+// CHECK: %[[REM1_IV3:.+]] = arith.remsi %[[IV3]], %[[C8]]
+// CHECK: %[[DIV1_IV3:.+]] = arith.divsi %[[IV3]], %[[C8]]
+// CHECK: %[[REM2_IV3:.+]] = arith.remsi %[[DIV1_IV3]], %[[C7]]
+// CHECK: %[[DIV2_IV3:.+]] = arith.divsi %[[DIV1_IV3]], %[[C7]]
+// CHECK: %[[IV4:.+]] = linalg.index 4
+// CHECK: %[[T0:.+]] = arith.addi %[[IV0]], %[[DIV_IV1]]
+// CHECK: %[[T1:.+]] = arith.addi %[[T0]], %[[REM_IV1]]
+// CHECK: %[[T2:.+]] = arith.addi %[[T1]], %[[IV2]]
+// CHECK: %[[T3:.+]] = arith.addi %[[T2]], %[[DIV2_IV3]]
+// CHECK: %[[T4:.+]] = arith.addi %[[T3]], %[[REM2_IV3]]
+// CHECK: %[[T5:.+]] = arith.addi %[[T4]], %[[REM1_IV3]]
+// CHECK: %[[T6:.+]] = arith.addi %[[T5]], %[[IV4]]
+// CHECK: %[[YIELD:.+]] = arith.index_cast %[[T6]]
+// CHECK: linalg.yield %[[YIELD]]
+
+// -----
+
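+// Like the first test, but the producer's indexing maps permute the dimensions,
+// so the operand collapse reassociations are permuted relative to the result's
+// [[0], [1, 2, 3], [4], [5, 6], [7]].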
+#map0 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d7, d5, d6, d0, d1, d2, d3, d4)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d0)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d4, d1, d2, d3)>
+#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>
+func.func @fuse_by_collapsing_change_reshape_order_bubblecollapse(%arg0 : tensor<9x7x8x2x3x4x5x6xi32>,
+    %arg1 : tensor<7x8x2xi32>, %arg2 : tensor<6x3x4x5xi32>) -> tensor<2x60x6x56x9xi32> {
+  %init = tensor.empty() : tensor<2x3x4x5x6x7x8x9xi32>
+  %generic = linalg.generic {
+      indexing_maps = [#map0, #map1, #map2, #map3],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]}
+      ins(%arg0, %arg1, %arg2 : tensor<9x7x8x2x3x4x5x6xi32>, tensor<7x8x2xi32>, tensor<6x3x4x5xi32>)
+      outs(%init : tensor<2x3x4x5x6x7x8x9xi32>) {
+    ^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32):
+      %t0 = arith.addi %b0, %b1 : i32
+      %t1 = arith.addi %t0, %b2 : i32
+      linalg.yield %t1 : i32
+  } -> tensor<2x3x4x5x6x7x8x9xi32>
+  %collapse = tensor.collapse_shape %generic [[0], [1, 2, 3], [4], [5, 6], [7]] : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x60x6x56x9xi32>
+  return %collapse : tensor<2x60x6x56x9xi32>
+}
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d4, d3, d0, d1, d2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d0)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d2, d1)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CHECK: func @fuse_by_collapsing_change_reshape_order_bubblecollapse(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<9x7x8x2x3x4x5x6xi32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<7x8x2xi32>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<6x3x4x5xi32>
+// CHECK-DAG: %[[INIT:.+]] = tensor.empty()
+// CHECK-DAG: %[[ARG0_RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]{{\]}}
+// CHECK-DAG: %[[ARG1_RESHAPE:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0, 1], [2]{{\]}}
+// CHECK-DAG: %[[ARG2_RESHAPE:.+]] = tensor.collapse_shape %[[ARG2]] {{\[}}[0], [1, 2, 3]{{\]}}
+// CHECK-DAG: %[[INIT_RESHAPE:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0], [1, 2, 3], [4], [5, 6], [7]{{\]}}
+// CHECK: %[[COLLAPSED_OP:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[ARG0_RESHAPE]], %[[ARG1_RESHAPE]], %[[ARG2_RESHAPE]] :
+// CHECK-SAME: outs(%[[INIT_RESHAPE]] :
+// CHECK: return %[[COLLAPSED_OP]]
+
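+// Under the CONTROL prefix, fusion is expected to be rejected for this case
+// (presumably by the test pass's fusion-control callback), so the generic op
+// stays unfused and the collapse_shape remains separate.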
+// CONTROL: func @fuse_by_collapsing_change_reshape_order_bubblecollapse(
+// CONTROL-SAME: %[[ARG0:.+]]: tensor<9x7x8x2x3x4x5x6xi32>
+// CONTROL-SAME: %[[ARG1:.+]]: tensor<7x8x2xi32>
+// CONTROL-SAME: %[[ARG2:.+]]: tensor<6x3x4x5xi32>
+// CONTROL: %[[GENERIC:.+]] = linalg.generic
+// CONTROL-SAME: ins(%[[ARG0]],
+// CONTROL: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[GENERIC]]
+// CONTROL: return %[[COLLAPSE]]