Commit 770eaa3

add tests for non const pad val
Signed-off-by: Max Dawkins <[email protected]>
1 parent f76287c commit 770eaa3
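
The Linalg reshape-with-pad fusion patterns move a tensor.expand_shape or tensor.collapse_shape across an adjacent tensor.pad. That rewrite is only safe when the value yielded by the pad region does not depend on the region's index arguments (in practice, a constant): re-grouping padded dimensions changes which indices reach the region, so an index-dependent pad value would change the padded contents. Each test added here yields an arith.index_cast of a region block argument and checks that the IR is left untouched. For contrast, a minimal sketch of the constant-padding form that the existing positive tests (e.g. @fuse_by_collapsing_pad) exercise; %expand and the shapes are illustrative, mirroring the first new test:

  // Pad value is a constant captured from outside the region: every padded
  // element is 0 regardless of its index, so a reshape may be moved across
  // the pad without changing the result.
  %cst = arith.constant 0 : i32
  %padded = tensor.pad %expand low[1, 0, 0] high[5, 0, 0] {
  ^bb0(%i: index, %j: index, %k: index):
    tensor.yield %cst : i32
  } : tensor<2x3x4xi32> to tensor<8x3x4xi32>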

2 files changed: +70 −0 lines changed

mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir

Lines changed: 36 additions & 0 deletions
@@ -594,6 +594,24 @@ func.func @fuse_by_collapsing_pad(%arg0 : tensor<2x12x5x336x9xi32>) -> tensor<8x
 
 // -----
 
+func.func @no_fuse_by_collapsing_pad_non_constant_padding(%arg0 : tensor<2x12xi32>) -> tensor<8x3x4xi32> {
+  %expand = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [2, 3, 4] : tensor<2x12xi32> into tensor<2x3x4xi32>
+  %cst = arith.constant 0 : i32
+  %padded_0 = tensor.pad %expand low[1, 0, 0] high[5, 0, 0] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index):
+    %pad_val = arith.index_cast %arg1 : index to i32
+    tensor.yield %pad_val : i32
+  } : tensor<2x3x4xi32> to tensor<8x3x4xi32>
+  return %padded_0 : tensor<8x3x4xi32>
+}
+// CHECK: func @no_fuse_by_collapsing_pad_non_constant_padding(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x12xi32>)
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]]
+// CHECK: %[[PAD:.+]] = tensor.pad %[[EXPAND]]
+// CHECK: return %[[PAD]]
+
+// -----
+
 func.func @no_fuse_by_collapsing_pad(%arg0 : tensor<2x12x5x336x9xi32>) -> tensor<8x5x4x17x6x7x8x14xi32> {
   %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32>
   %cst = arith.constant 0 : i32
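
The "// -----" lines are split-input-file separators, so each function runs as an independent lit case; the CHECK lines of the new test assert that the expand_shape and the pad both survive, unfused and in their original order.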
@@ -678,6 +696,24 @@ func.func @collapse_shape_with_producer_pad_dynamic(%arg0: tensor<?x?x?x?x?x?xf3
 // CHECK: %[[PAD:.+]] = tensor.pad %[[COLLAPSE]] low[%[[L0]], 0, %[[L1]], 0] high[%[[H0]], 0, %[[H1]], 0]
 // CHECK: return %[[PAD]]
 
+// -----
+
+func.func @collapse_shape_with_producer_pad_non_constant_padding(%arg0 : tensor<2x3x4xi32>) -> tensor<8x12xi32> {
+  %cst = arith.constant 0 : i32
+  %padded_0 = tensor.pad %arg0 low[1, 0, 0] high[5, 0, 0] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index):
+    %pad_val = arith.index_cast %arg1 : index to i32
+    tensor.yield %pad_val : i32
+  } : tensor<2x3x4xi32> to tensor<8x3x4xi32>
+  %collapsed = tensor.collapse_shape %padded_0 [[0], [1, 2]] : tensor<8x3x4xi32> into tensor<8x12xi32>
+  return %collapsed : tensor<8x12xi32>
+}
+// CHECK: func @collapse_shape_with_producer_pad_non_constant_padding(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4xi32>)
+// CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]]
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[PAD]]
+// CHECK: return %[[COLLAPSED]]
+
 // -----
 // Static problem sizes. Checks all aspects of fusion by collapsing with bubbling up collapse shapes.
 #map0 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>

mlir/test/Dialect/Linalg/reshape_fusion.mlir

Lines changed: 34 additions & 0 deletions
@@ -822,6 +822,23 @@ func.func @fuse_by_expanding_pad(%arg0 : tensor<2x3x4x5x6x7x8x9xi32>) -> tensor<
 
 // -----
 
+func.func @no_fuse_by_expanding_pad_non_constant_padding(%arg0 : tensor<2x3x4xi32>) -> tensor<8x12xi32> {
+  %collapse = tensor.collapse_shape %arg0 [[0], [1, 2]] : tensor<2x3x4xi32> into tensor<2x12xi32>
+  %padded_0 = tensor.pad %collapse low[1, 0] high[5, 0] {
+  ^bb0(%arg1: index, %arg2: index):
+    %pad_val = arith.index_cast %arg1 : index to i32
+    tensor.yield %pad_val : i32
+  } : tensor<2x12xi32> to tensor<8x12xi32>
+  return %padded_0 : tensor<8x12xi32>
+}
+// CHECK: func @no_fuse_by_expanding_pad_non_constant_padding(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4xi32>)
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]]
+// CHECK: %[[PAD:.+]] = tensor.pad %[[COLLAPSE]]
+// CHECK: return %[[PAD]]
+
+// -----
+
 func.func @no_fuse_by_expanding_pad(%arg0 : tensor<2x3x4x5x6x7x8x9xi32>) -> tensor<8x12x17x339x14xi32> {
   %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x12x5x336x9xi32>
   %cst = arith.constant 0 : i32
@@ -904,6 +921,23 @@ func.func @expand_shape_with_producer_pad_dynamic(%arg0: tensor<?x?x?x?xf32>,
 
 // -----
 
+func.func @expand_shape_with_producer_pad_non_constant_padding(%arg0 : tensor<2x12xi32>) -> tensor<8x3x4xi32> {
+  %padded_0 = tensor.pad %arg0 low[1, 0] high[5, 0] {
+  ^bb0(%arg1: index, %arg2: index):
+    %pad_val = arith.index_cast %arg1 : index to i32
+    tensor.yield %pad_val : i32
+  } : tensor<2x12xi32> to tensor<8x12xi32>
+  %expand = tensor.expand_shape %padded_0 [[0], [1, 2]] output_shape [8, 3, 4] : tensor<8x12xi32> into tensor<8x3x4xi32>
+  return %expand : tensor<8x3x4xi32>
+}
+// CHECK: func @expand_shape_with_producer_pad_non_constant_padding(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x12xi32>)
+// CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]]
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[PAD]]
+// CHECK: return %[[EXPAND]]
+
+// -----
+
 func.func @move_operand_deps(%arg0 : tensor<?x128xf16>,
     %arg1 : tensor<4x?x32x128xf16>, %empty : tensor<4x?x32x128xf16>) -> tensor<4x?x32x8x16xf16> {
   %c0 = arith.constant 0 : index
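
All four new cases share one shape: a tensor.pad whose yielded value is computed from the region's index arguments sits adjacent to a reshape, and the CHECK lines confirm that neither the collapsing patterns (fuse-with-reshape-by-collapsing.mlir) nor the expanding patterns (reshape_fusion.mlir) rewrite the pair.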
