@@ -594,6 +594,24 @@ func.func @fuse_by_collapsing_pad(%arg0 : tensor<2x12x5x336x9xi32>) -> tensor<8x
594594
595595// -----
596596
// Negative test: a tensor.expand_shape feeds a tensor.pad whose region yields
// a non-constant padding value (an index_cast of the first region argument).
// The FileCheck directives below require the expand_shape to remain the
// producer of the pad, i.e. the two ops are not rewritten into a collapsed
// form.
func.func @no_fuse_by_collapsing_pad_non_constant_padding(%arg0 : tensor<2x12xi32>) -> tensor<8x3x4xi32> {
  %expand = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [2, 3, 4] : tensor<2x12xi32> into tensor<2x3x4xi32>
  %padded_0 = tensor.pad %expand low[1, 0, 0] high[5, 0, 0] {
  ^bb0(%arg1: index, %arg2: index, %arg3: index):
    // The yielded value is computed from %arg1 rather than taken from a
    // constant defined outside the region.
    %pad_val = arith.index_cast %arg1 : index to i32
    tensor.yield %pad_val : i32
  } : tensor<2x3x4xi32> to tensor<8x3x4xi32>
  return %padded_0 : tensor<8x3x4xi32>
}
// CHECK: func @no_fuse_by_collapsing_pad_non_constant_padding(
// CHECK-SAME: %[[ARG0:.+]]: tensor<2x12xi32>)
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]]
// CHECK: %[[PAD:.+]] = tensor.pad %[[EXPAND]]
// CHECK: return %[[PAD]]
612+
613+ // -----
614+
597615func.func @no_fuse_by_collapsing_pad (%arg0 : tensor <2 x12 x5 x336 x9 xi32 >) -> tensor <8 x5 x4 x17 x6 x7 x8 x14 xi32 > {
598616 %expand = tensor.expand_shape %arg0 [[0 ], [1 , 2 ], [3 ], [4 , 5 , 6 ], [7 ]] output_shape [2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ] : tensor <2 x12 x5 x336 x9 xi32 > into tensor <2 x3 x4 x5 x6 x7 x8 x9 xi32 >
599617 %cst = arith.constant 0 : i32
@@ -678,6 +696,24 @@ func.func @collapse_shape_with_producer_pad_dynamic(%arg0: tensor<?x?x?x?x?x?xf3
678696// CHECK: %[[PAD:.+]] = tensor.pad %[[COLLAPSE]] low[%[[L0]], 0, %[[L1]], 0] high[%[[H0]], 0, %[[H1]], 0]
679697// CHECK: return %[[PAD]]
680698
699+ // -----
700+
// Negative test: a tensor.pad whose region yields a non-constant padding
// value (an index_cast of the first region argument) feeds a
// tensor.collapse_shape. The FileCheck directives below require the pad to
// remain the producer of the collapse_shape, i.e. the collapse is not moved
// above the pad.
func.func @collapse_shape_with_producer_pad_non_constant_padding(%arg0 : tensor<2x3x4xi32>) -> tensor<8x12xi32> {
  %padded_0 = tensor.pad %arg0 low[1, 0, 0] high[5, 0, 0] {
  ^bb0(%arg1: index, %arg2: index, %arg3: index):
    // The yielded value is computed from %arg1 rather than taken from a
    // constant defined outside the region.
    %pad_val = arith.index_cast %arg1 : index to i32
    tensor.yield %pad_val : i32
  } : tensor<2x3x4xi32> to tensor<8x3x4xi32>
  %collapsed = tensor.collapse_shape %padded_0 [[0], [1, 2]] : tensor<8x3x4xi32> into tensor<8x12xi32>
  return %collapsed : tensor<8x12xi32>
}
// CHECK: func @collapse_shape_with_producer_pad_non_constant_padding(
// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4xi32>)
// CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]]
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[PAD]]
// CHECK: return %[[COLLAPSED]]
716+
681717// -----
682718// Static problem sizes. Checks all aspects of fusion by collapsing with bubbling up collapse shapes.
683719#map0 = affine_map <(d0 , d1 , d2 , d3 , d4 , d5 , d6 , d7 ) -> (d0 , d1 , d2 , d3 , d4 , d5 , d6 , d7 )>
0 commit comments